wmma_gemm.hpp Source File
wmma_gemm.hpp
Go to the documentation of this file.
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
@ wmma_f32_16x16x16_bf16_gfx12
Definition wmma_gemm.hpp:23
@ wmma_f32_16x16x16_bf8f8_gfx12
Definition wmma_gemm.hpp:27
@ wmma_f32_16x16x16_bf8bf8_gfx12
Definition wmma_gemm.hpp:28
@ wmma_f32_16x16x16_f8f8_gfx12
Definition wmma_gemm.hpp:25
@ wmma_f32_16x16x16_f8bf8_gfx12
Definition wmma_gemm.hpp:26
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition utility/sequence.hpp:43
__host__ static __device__ constexpr auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition wmma_gemm.hpp:727
static __device__ constexpr index_t GetRegSizePerWmma()
Definition wmma_gemm.hpp:764
static __device__ constexpr index_t GetWaveSize()
Definition wmma_gemm.hpp:769
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition wmma_gemm.hpp:772
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition wmma_gemm.hpp:829
__host__ static __device__ constexpr auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition wmma_gemm.hpp:868
static __device__ auto GetSwizzledLaneIdLow()
Definition wmma_gemm.hpp:824
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition wmma_gemm.hpp:838
static __device__ CIndex GetBeginOfThreadBlk()
Definition wmma_gemm.hpp:847
__host__ static __device__ constexpr auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition wmma_gemm.hpp:687
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition wmma_gemm.hpp:855
static __device__ auto GetLaneIdUnderSubGroup()
Definition wmma_gemm.hpp:820
Definition wmma_gemm.hpp:553
static constexpr auto GetWmma()
__host__ __device__ constexpr WmmaSelector()
Definition wmma_gemm.hpp:639
Definition amd_wmma.hpp:96
Definition amd_wmma.hpp:216
Definition amd_wmma.hpp:72
Definition amd_wmma.hpp:192
Definition amd_wmma.hpp:297
Definition amd_wmma.hpp:50
Definition amd_wmma.hpp:170
Definition amd_wmma.hpp:418
Definition amd_wmma.hpp:394
Definition amd_wmma.hpp:271
Definition amd_wmma.hpp:25
Definition amd_wmma.hpp:149
Definition amd_wmma.hpp:370
Definition amd_wmma.hpp:346
Definition amd_wmma.hpp:319
Definition amd_wmma.hpp:121
Definition amd_wmma.hpp:241
Definition functional2.hpp:33
static constexpr index_t num_src_b_vgprs_per_wave
Definition wmma_gemm.hpp:221
static constexpr index_t wave_size
Definition wmma_gemm.hpp:219
static constexpr index_t num_src_a_vgprs_per_wave
Definition wmma_gemm.hpp:220
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:215
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:224
static constexpr index_t src_a_data_size
Definition wmma_gemm.hpp:212
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:214
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:211
static constexpr index_t src_b_data_size
Definition wmma_gemm.hpp:213
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:209
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:232
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:216
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:222
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:210
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:186
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:174
static constexpr index_t wave_size
Definition wmma_gemm.hpp:183
static constexpr index_t src_a_data_size
Definition wmma_gemm.hpp:176
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:173
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:175
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:191
static constexpr index_t num_src_b_vgprs_per_wave
Definition wmma_gemm.hpp:185
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:179
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:188
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:180
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:178
static constexpr index_t src_b_data_size
Definition wmma_gemm.hpp:177
static constexpr index_t num_src_a_vgprs_per_wave
Definition wmma_gemm.hpp:184
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:154
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:138
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:149
static constexpr index_t num_src_b_vgprs_per_wave
Definition wmma_gemm.hpp:148
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:136
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:151
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:141
static constexpr index_t wave_size
Definition wmma_gemm.hpp:146
static constexpr index_t num_src_a_vgprs_per_wave
Definition wmma_gemm.hpp:147
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:137
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:142
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:143
static constexpr index_t src_a_data_size
Definition wmma_gemm.hpp:139
static constexpr index_t src_b_data_size
Definition wmma_gemm.hpp:140
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:341
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:346
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:352
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:340
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:339
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:344
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:353
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:345
static constexpr index_t wave_size
Definition wmma_gemm.hpp:349
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:356
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:519
static constexpr index_t wave_size
Definition wmma_gemm.hpp:526
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:522
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:523
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:520
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:518
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:528
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:521
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:531
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:527
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:485
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:488
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:496
static constexpr index_t wave_size
Definition wmma_gemm.hpp:491
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:483
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:484
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:493
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:486
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:492
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:487
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:100
static constexpr index_t src_a_data_size
Definition wmma_gemm.hpp:98
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:114
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:117
static constexpr index_t num_src_a_vgprs_per_wave
Definition wmma_gemm.hpp:108
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:95
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:97
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:103
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:101
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:112
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:96
static constexpr index_t wave_size
Definition wmma_gemm.hpp:106
static constexpr index_t num_src_b_vgprs_per_wave
Definition wmma_gemm.hpp:109
static constexpr index_t src_b_data_size
Definition wmma_gemm.hpp:99
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:323
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:301
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:303
static constexpr index_t wave_size
Definition wmma_gemm.hpp:313
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:308
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:302
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:319
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:309
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:320
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:310
static constexpr index_t wave_size
Definition wmma_gemm.hpp:456
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:461
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:458
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:453
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:449
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:450
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:448
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:452
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:451
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:457
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:415
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:416
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:423
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:417
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:422
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:413
static constexpr index_t wave_size
Definition wmma_gemm.hpp:421
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:418
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:426
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:414
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:266
static constexpr index_t num_src_b_vgprs_per_wave
Definition wmma_gemm.hpp:263
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:252
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:253
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:258
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:256
static constexpr index_t wave_size
Definition wmma_gemm.hpp:261
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:257
static constexpr index_t src_a_data_size
Definition wmma_gemm.hpp:254
static constexpr index_t src_b_data_size
Definition wmma_gemm.hpp:255
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:276
static constexpr index_t num_src_a_vgprs_per_wave
Definition wmma_gemm.hpp:262
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:251
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:264
static constexpr index_t wave_size
Definition wmma_gemm.hpp:382
static constexpr index_t m_per_wmma
Definition wmma_gemm.hpp:372
static constexpr index_t acc_pack_number
Definition wmma_gemm.hpp:378
static constexpr index_t acc_data_size
Definition wmma_gemm.hpp:377
static constexpr index_t k_per_wmma
Definition wmma_gemm.hpp:374
static constexpr index_t num_subgroups
Definition wmma_gemm.hpp:386
static constexpr index_t num_acc_vgprs_per_wave
Definition wmma_gemm.hpp:385
static constexpr index_t num_thread_per_subgroups
Definition wmma_gemm.hpp:379
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition wmma_gemm.hpp:396
static constexpr index_t n_per_wmma
Definition wmma_gemm.hpp:373
Definition wmma_gemm.hpp:84