gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp Source File

Composable Kernel: gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp Source File
gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
20template <typename FloatAB,
21 typename FloatGemmAcc,
22 typename FloatCShuffle,
23 typename FloatC,
24 typename AElementwiseOperation,
25 typename BElementwiseOperation,
26 typename AccElementwiseOperation,
27 typename B1ElementwiseOperation,
28 typename CElementwiseOperation,
29 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
30 typename AGridDesc_AK0_M_AK1,
31 typename BGridDesc_BK0_N_BK1,
32 typename B1GridDesc_BK0_N_BK1,
33 typename CGridDesc_M_N,
34 index_t NumGemmKPrefetchStage,
35 index_t BlockSize,
36 index_t MPerBlock,
37 index_t NPerBlock,
38 index_t KPerBlock,
39 index_t Gemm1NPerBlock,
40 index_t Gemm1KPerBlock,
41 index_t AK1Value,
42 index_t BK1Value,
43 index_t B1K1Value,
44 index_t MPerXdl,
45 index_t NPerXdl,
46 index_t MXdlPerWave,
47 index_t NXdlPerWave,
48 index_t Gemm1NXdlPerWave,
49 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50 typename ABlockTransferThreadClusterArrangeOrder,
51 typename ABlockTransferSrcAccessOrder,
52 index_t ABlockTransferSrcVectorDim,
53 index_t ABlockTransferSrcScalarPerVector,
54 index_t ABlockTransferDstScalarPerVector_AK1,
55 bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
56 index_t ABlockLdsExtraM,
57 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58 typename BBlockTransferThreadClusterArrangeOrder,
59 typename BBlockTransferSrcAccessOrder,
60 index_t BBlockTransferSrcVectorDim,
61 index_t BBlockTransferSrcScalarPerVector,
62 index_t BBlockTransferDstScalarPerVector_BK1,
63 bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
64 index_t BBlockLdsExtraN,
65 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
66 typename B1BlockTransferThreadClusterArrangeOrder,
67 typename B1BlockTransferSrcAccessOrder,
68 index_t B1BlockTransferSrcVectorDim,
69 index_t B1BlockTransferSrcScalarPerVector,
70 index_t B1BlockTransferDstScalarPerVector_BK1,
71 bool B1ThreadTransferSrcResetCoordinateAfterRun,
72 index_t B1BlockLdsExtraN,
73 index_t CShuffleMXdlPerWavePerShuffle,
74 index_t CShuffleNXdlPerWavePerShuffle,
75 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
76 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
77 LoopScheduler LoopSched,
80{
// Only the default loop scheduler is accepted; any other LoopSched value fails to compile.
81 static_assert(LoopSched == LoopScheduler::Default,
82 "Non-default loop scheduler is currently not supported");
83
// Compile-time index constants (Number<N> instances) used to address descriptor
// dimensions and tuple elements, e.g. GetLength(I0) or some_tuple[I1].
84 static constexpr auto I0 = Number<0>{};
85 static constexpr auto I1 = Number<1>{};
86 static constexpr auto I2 = Number<2>{};
87 static constexpr auto I3 = Number<3>{};
88 static constexpr auto I4 = Number<4>{};
89 static constexpr auto I5 = Number<5>{};
90 static constexpr auto I6 = Number<6>{};
91 static constexpr auto I7 = Number<7>{};
92
93 // K1 should be Number<...>
94 // Gemm0
// Per-block K extent of Gemm0 split as K0 x K1: K0 = KPerBlock / K1Value, K1 = K1Value.
// The division is an integer division on the template arguments.
95 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
96 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
97 static constexpr auto AK1 = Number<AK1Value>{};
98 static constexpr auto BK1 = Number<BK1Value>{};
99 // Gemm1
// Same K0 x K1 split for the B1 operand, based on Gemm1KPerBlock and B1K1Value.
100 static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
101 static constexpr auto B1K1 = Number<B1K1Value>{};
102
104
107
108 template <typename ABlockDesc_AK0_M_AK1>
109 __host__ __device__ static constexpr auto
110 MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
111 {
112 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
113
114 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
115 ABlockDesc_AK0_M_AK1{});
116 }
117
118 template <typename BBlockDesc_BK0_N_BK1>
119 __host__ __device__ static constexpr auto
120 MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
121 {
122 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
123
124 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
125 BBlockDesc_BK0_N_BK1{});
126 }
127
128 template <typename ABlockDesc_AK0_M_AK1>
129 __host__ __device__ static constexpr auto
130 MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
131 {
132 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, 1, 1>(ABlockDesc_AK0_M_AK1{});
133 }
134
135 template <typename BBlockDesc_BK0_N_BK1>
136 __host__ __device__ static constexpr auto
137 MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
138 {
139 constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
140 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
141 BBlockDesc_BK0_N_BK1{});
142 }
143
144 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
145 {
146 // A matrix in LDS memory, dst of blockwise copy
150 }
151
152 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
153 {
154 // B matrix in LDS memory, dst of blockwise copy
158 }
159
160 __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
161 {
162 // B1 matrix in LDS memory, dst of blockwise copy
166 }
167
168 __host__ __device__ static constexpr auto
170 {
171 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
172 constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
173
174 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
178 I1,
180
181 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
182 }
183
184 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
185 {
186 const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
188 sizeof(FloatAB);
189 const index_t gemm1_bytes_end =
191 sizeof(FloatAB);
192 const index_t c_block_bytes_end =
193 SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
194
195 return math::max(gemm0_bytes_end, gemm1_bytes_end, c_block_bytes_end);
196 }
197
198 template <
199 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
200 __device__ static bool constexpr IsValidCompilationParameter()
201 {
202 constexpr bool valid = ck::tensor_operation::device::IsValidGemmCompilationParameter<
203 BlockSize,
204 MPerBlock,
205 NPerBlock,
206 MPerXdl,
207 NPerXdl,
208 MXdlPerWave,
209 NXdlPerWave,
210 FloatC,
211 CGlobalMemoryDataOperation>();
212 if constexpr(!valid)
213 {
214 return false;
215 }
216
217 return true;
218 }
219
220 // The block_id to matrix tile index (m0, n0) mapping is controlled by {M01, N01}
221 template <typename Block2CTileMap>
222 __host__ __device__ static constexpr bool
223 CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
224 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
225 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
226 const CGridDesc_M_N& c_grid_desc_m_n,
227 const Block2CTileMap& block_2_ctile_map)
228 {
229 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
230 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
231 "Invalid tuning param!");
232
233 const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
234 const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
235 const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
236 const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
237
238 if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
239 {
240 return false;
241 }
242
243 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
244 Gemm1N % Gemm1NPerBlock == 0))
245 {
246 return false;
247 }
248
249 // check gemm0 gridwise gemm pipeline
250 const auto num_gemm0_k_loop = K / KPerBlock;
251 if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
252 {
253 return false;
254 }
255
256 // check gemm1 gridwise gemm pipeline
257 if(!(NPerBlock % Gemm1KPerBlock == 0))
258 {
259 return false;
260 }
261
262 const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
263 if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
264 {
265 return false;
266 }
267
268 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
269 {
270 return false;
271 }
272
273 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
274 return true;
275 }
276
277 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
278 {
279 const index_t num_loop = K / KPerBlock;
280
281 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
282 }
283
284 __host__ __device__ static constexpr auto
285 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
286 {
287 const auto M = c_grid_desc_m_n.GetLength(I0);
288 const auto N = c_grid_desc_m_n.GetLength(I1);
289
290 const auto MBlock = M / MPerBlock;
291 const auto NBlock = N / Gemm1NPerBlock;
292
293 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
294 c_grid_desc_m_n,
299
300 return c_grid_desc_mblock_mperblock_nblock_nperblock;
301 }
302
303 // return block_id to C matrix tile idx (m0, n0) mapping
304 __host__ __device__ static constexpr auto
305 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
306 {
308 c_grid_desc_m_n);
309 }
310
313 CGridDesc_M_N{}))>;
314
316 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
317
319 {
320 // LDS allocation for A and B: be careful of alignment
321 static constexpr auto a_block_desc_ak0_m_ak1 =
323 static constexpr auto b_block_desc_bk0_n_bk1 =
325 static constexpr auto b1_block_desc_bk0_n_bk1 =
327
328 static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
329
331 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
333 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
335 b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
336
337 static constexpr auto a_block_space_offset = 0;
339 static constexpr auto b1_block_space_offset = 0;
340
341 // LDS allocation for C shuffle in LDS
344 static constexpr auto c_block_space_size =
346 };
347
348 template <bool HasMainKBlockLoop, typename Block2CTileMap>
349 __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
350 const FloatAB* __restrict__ p_b_grid,
351 const FloatAB* __restrict__ p_b1_grid,
352 FloatC* __restrict__ p_c_grid,
353 void* __restrict__ p_shared,
354 const AElementwiseOperation& a_element_op,
355 const BElementwiseOperation& b_element_op,
356 const AccElementwiseOperation& acc_element_op,
357 const B1ElementwiseOperation& b1_element_op,
358 const CElementwiseOperation& c_element_op,
359 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
360 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
361 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
363 c_grid_desc_mblock_mperblock_nblock_nperblock,
364 const Block2CTileMap& block_2_ctile_map)
365 {
366 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
367 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
368 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
369 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
370 const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
371 p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
373 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
374
375 // divide block work by [M, N]
376 const auto block_work_idx =
377 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
378
379 if(!block_2_ctile_map.ValidCTileIndex(
380 block_work_idx,
381 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
382 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
383 {
384 return;
385 }
386
387 // HACK: this force m/n_block_data_idx_on_grid into SGPR
388 const index_t m_block_data_idx_on_grid =
389 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
390
391 const index_t n_block_data_idx_on_grid =
392 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
393
394 // A matrix in LDS memory, dst of blockwise copy
395 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
396
397 // B matrix in LDS memory, dst of blockwise copy
398 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
399
400 //
401 // set up Gemm0
402 //
403
404 // A matrix blockwise copy
405 auto a_blockwise_copy =
407 AElementwiseOperation,
411 ABlockTransferThreadClusterLengths_AK0_M_AK1,
412 ABlockTransferThreadClusterArrangeOrder,
413 FloatAB,
414 FloatAB,
415 decltype(a_grid_desc_ak0_m_ak1),
416 decltype(a_block_desc_ak0_m_ak1),
417 ABlockTransferSrcAccessOrder,
419 ABlockTransferSrcVectorDim,
420 2,
421 ABlockTransferSrcScalarPerVector,
422 ABlockTransferDstScalarPerVector_AK1,
423 1,
424 1,
425 true, // SrcResetCoord
426 true, // DstResetCoord
427 NumGemmKPrefetchStage>(
428 a_grid_desc_ak0_m_ak1,
429 make_multi_index(0, m_block_data_idx_on_grid, 0),
430 a_element_op,
431 a_block_desc_ak0_m_ak1,
432 make_multi_index(0, 0, 0),
434
435 // B matrix blockwise copy
436 auto b_blockwise_copy =
438 BElementwiseOperation,
442 BBlockTransferThreadClusterLengths_BK0_N_BK1,
443 BBlockTransferThreadClusterArrangeOrder,
444 FloatAB,
445 FloatAB,
446 decltype(b_grid_desc_bk0_n_bk1),
447 decltype(b_block_desc_bk0_n_bk1),
448 BBlockTransferSrcAccessOrder,
450 BBlockTransferSrcVectorDim,
451 2,
452 BBlockTransferSrcScalarPerVector,
453 BBlockTransferDstScalarPerVector_BK1,
454 1,
455 1,
456 true, // SrcResetCoord
457 true, // DstResetCoord
458 NumGemmKPrefetchStage>(
459 b_grid_desc_bk0_n_bk1,
460 make_multi_index(0, 0, 0), // will loop over GemmN dimension
461 b_element_op,
462 b_block_desc_bk0_n_bk1,
463 make_multi_index(0, 0, 0),
465
466 // Fused Gemm+Gemm pipeline
467 // for n in N0:
468 // for k in K0:
469 // acc[m][n] += A[m][k] * B0[k][n]
470 // acc1[m][o] += acc[m][n] * B1[n][o]
471
472 // sanity check
473 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
474 constexpr bool is_single_rate_mfma =
476 lcm_AK1_BK1 <= 4) ||
477 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
479 lcm_AK1_BK1 < 32))
480 ? true
481 : false;
482 constexpr auto is_scale_mfma = false;
483 constexpr index_t KPack = math::max(
484 lcm_AK1_BK1,
486 selected_mfma.k_per_blk);
487
488 auto blockwise_gemm = BlockwiseGemmXdlops_v2<
489 BlockSize,
490 FloatAB,
491 FloatGemmAcc,
492 decltype(a_block_desc_ak0_m_ak1),
493 decltype(b_block_desc_bk0_n_bk1),
494 decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
495 decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
496 MPerBlock,
497 NPerBlock,
498 KPerBlock,
499 MPerXdl,
500 NPerXdl,
501 MXdlPerWave,
502 NXdlPerWave,
503 KPack,
504 true>{}; // TransposeC
505
506 auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
507
508 // LDS allocation for A and B: be careful of alignment
510 static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
511 a_block_desc_ak0_m_ak1.GetElementSpaceSize());
512
514 static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
515 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
516
517 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
518 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
519 const auto a_block_reset_copy_step =
520 make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
521 const auto b_block_reset_copy_step =
522 make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
523
524 // gridwise GEMM pipeline
525 // Only supports LoopScheduler::Default
526 const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
527 NumGemmKPrefetchStage,
529
530 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
531 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
532 KPerBlock);
533
534 //
535 // set up Gemm1
536 //
537
538 // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
539 constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
540 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
541
542 constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
543 constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
544 constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
545 constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
546 constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
547 constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
548 constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
549 constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
550
551 constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
552
553 // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
554 // n0_n1_n2_n3 -> k0
555 // m0_m1_m2 -> m
556 // n4 -> k1
557 // NOTE: had to use merge_v3 or will spit out compilation errors
558 constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
559 acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
565
566 // A1 matrix in AccVGPR
567 // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
568 constexpr auto AccN3 =
569 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
570
571 constexpr auto A1ThreadSlice_K0_M_K1 =
573
574 constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
575 constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
576 constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
577
578#if defined(__gfx11__)
579 constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
580 make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number<A1ThreadSliceK1 * 2>{}));
582 FloatGemmAcc,
583 FloatAB,
584 decltype(acc_thread_desc_k0_m_k1),
585 decltype(a1_thread_desc_k0_m_k1),
586 decltype(acc_element_op),
589 2,
590 n4,
591 0x76543210,
592 0xfedcba98,
593 true>{make_tuple(0, 0, 0)};
594#else
595 constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
596 A1ThreadSlice_K0_M_K1,
597 make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
598 // A1 matrix blockwise copy
599 auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
600 FloatGemmAcc,
601 FloatAB,
602 decltype(acc_thread_desc_k0_m_k1),
603 decltype(a1_thread_desc_k0_m_k1),
604 decltype(acc_element_op),
607 2,
608 n4>{acc_element_op};
609#endif
610 // B1 matrix in LDS memory, dst of blockwise copy
611 constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
612
613 // B1 matrix blockwise copy
614 auto b1_blockwise_copy =
616 BElementwiseOperation,
620 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
621 B1BlockTransferThreadClusterArrangeOrder,
622 FloatAB,
623 FloatAB,
624 decltype(b1_grid_desc_bk0_n_bk1),
625 decltype(b1_block_desc_bk0_n_bk1),
626 B1BlockTransferSrcAccessOrder,
628 B1BlockTransferSrcVectorDim,
629 2,
630 B1BlockTransferSrcScalarPerVector,
631 B1BlockTransferDstScalarPerVector_BK1,
632 1,
633 1,
634 B1ThreadTransferSrcResetCoordinateAfterRun,
635 true, // DstResetCoord
636 NumGemmKPrefetchStage>(
637 b1_grid_desc_bk0_n_bk1,
638 make_multi_index(0, n_block_data_idx_on_grid, 0),
639 b1_element_op,
640 b1_block_desc_bk0_n_bk1,
641 make_multi_index(0, 0, 0),
643
645 a1_thread_desc_k0_m_k1.GetElementSpaceSize());
646
647 // reuse LDS space for gemm0's b_block_buf
649 static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
650 b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
651
652 // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
653 // selected_mfma.k_per_blk <= Gemm1KPack
654 //
655 // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
656 // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
657 // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
658 // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
659 // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
660 // therefore we may just as well assign Gemm1KPack = group_size
661#if defined(__gfx11__)
662 constexpr index_t Gemm1KPack =
664 static_assert(
666#else
667 constexpr index_t Gemm1KPack =
669#endif
670 auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
671 BlockSize,
672 FloatAB,
673 FloatGemmAcc,
674 decltype(a1_thread_desc_k0_m_k1),
675 decltype(b1_block_desc_bk0_n_bk1),
676 decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
677 decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
678 MPerBlock,
679 Gemm1NPerBlock,
680 Gemm1KPerBlock,
681 MPerXdl,
682 NPerXdl,
683 MXdlPerWave,
684 Gemm1NXdlPerWave,
685 Gemm1KPack,
686 false, // TransposeC
687 Gemm1KPack, // AMmaKStride
688 Gemm1KPack *
690 // BMmaKStride
691 make_tuple(0, 0, 0, 0)}; // A_origin
692
693 auto c_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
694
695 const index_t num_gemm1_k_block_outer_loop =
696 b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock;
697 constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock;
698
699 // Initialize C
700 c_thread_buf.Clear();
701
702 // gemm1 K loop
703 index_t gemm1_k_block_outer_index = 0;
704 do
705 {
706 // gemm0
707 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
708 a_block_desc_ak0_m_ak1,
709 a_blockwise_copy,
710 a_grid_buf,
711 a_block_buf,
712 a_block_slice_copy_step,
713 b_grid_desc_bk0_n_bk1,
714 b_block_desc_bk0_n_bk1,
715 b_blockwise_copy,
716 b_grid_buf,
717 b_block_buf,
718 b_block_slice_copy_step,
719 blockwise_gemm,
720 acc_thread_buf,
721 num_k_block_main_loop);
722 // gemm1
723 {
724 // TODO: explore using dynamic buffer for a1 thread buffer
725 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
726 // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
727 // the A1 source buffer is static buffer holding the output of first GEMM and
728 // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
729 // explicitly in Run() below.
730
731 // preload data into LDS
732 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
733
734 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
735 b1_block_slice_copy_step);
736
737 block_sync_lds(); // wait for gemm0 LDS read
738
739 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
740
741 // main body
742 if constexpr(num_gemm1_k_block_inner_loop > 1)
743 {
744 static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
745 a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
747 acc_thread_buf,
748 a1_thread_desc_k0_m_k1,
749 make_tuple(I0, I0, I0),
750 a1_thread_buf);
751
752 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
753
755
756 gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
757
759
760 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
761 b1_block_slice_copy_step);
762
763 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
764 });
765 }
766 // tail
767 {
768 a1_blockwise_copy.Run(
769 acc_thread_desc_k0_m_k1,
771 Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
772 acc_thread_buf,
773 a1_thread_desc_k0_m_k1,
774 make_tuple(I0, I0, I0),
775 a1_thread_buf);
776
778
779 gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, c_thread_buf);
780 }
781 } // end gemm1
782
783 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
784 a_block_reset_copy_step); // rewind K
785 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
786 b_block_reset_copy_step); // rewind K and step N
787
788 block_sync_lds(); // wait for gemm1 LDS read
789 } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
790
791 // shuffle C and write out
792 {
793 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
794 Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
795 "wrong!");
796
797 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
798 constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
799
800 // TODO: hacky, fix it!
801 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
802 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
803
804 // TODO: hacky, fix it!
805 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
806 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
807 gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
808
809 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
810 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
811 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
812 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
813 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
814 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
815 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
816 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
817
818 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
820
821 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
822 static_cast<FloatCShuffle*>(p_shared),
823 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
824
825 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
826 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
830 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
831 M1, // M1 = MWave
832 M2, // M2 * M3 * M4 = MPerXdl
833 M3,
834 M4)),
837 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
838 N1, // N1 = NWave
839 N2))), // N2 = NPerXdl
843
844 // calculate origin of thread output tensor on global memory
845 // blockwise GEMM c matrix starting index
846 const auto c_thread_mtx_on_block =
847 gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
848
849 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
850 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
851
852 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
854 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
857
858 const auto m_thread_data_on_block_idx =
859 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
860 make_multi_index(m_thread_data_on_block));
861
862 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
867
868 const auto n_thread_data_on_block_idx =
869 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
870 make_multi_index(n_thread_data_on_block));
871
872 // shuffle: threadwise copy C from VGPR to LDS
873 auto c_thread_copy_vgpr_to_lds =
875 FloatCShuffle,
876 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
877 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
879 Sequence<CShuffleMXdlPerWavePerShuffle,
880 CShuffleNXdlPerWavePerShuffle,
881 I1,
882 I1,
883 M2,
884 I1,
885 M4,
886 I1>,
888 7,
889 1,
891 1,
892 true>{
893 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
895 0,
896 m_thread_data_on_block_idx[I1],
897 n_thread_data_on_block_idx[I1],
898 m_thread_data_on_block_idx[I2],
899 m_thread_data_on_block_idx[I3],
900 m_thread_data_on_block_idx[I4],
901 n_thread_data_on_block_idx[I2]),
903
904 // shuffle: blockwise copy C from LDS to global
905 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
906 ThisThreadBlock, // ThreadGroup
907 CElementwiseOperation, // ElementwiseOperation,
908 CGlobalMemoryDataOperation, // DstInMemOp,
909 Sequence<1,
910 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
911 1,
912 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
913 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
914 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
915 FloatCShuffle, // typename SrcData,
916 FloatC, // typename DstData,
917 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
918 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
919 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
920 3, // index_t VectorDim,
921 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
922 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
923 false> // bool ThreadTransferDstResetCoordinateAfterRun>
924 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
925 make_multi_index(0, 0, 0, 0),
926 c_grid_desc_mblock_mperblock_nblock_nperblock,
927 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
928 c_element_op};
929
930 // space filling curve for threadwise C in VGPR
931 constexpr auto sfc_c_vgpr =
934 Sequence<CShuffleMXdlPerWavePerShuffle,
935 CShuffleNXdlPerWavePerShuffle,
936 1,
937 1,
938 M2,
939 1,
940 M4,
941 1>>{};
942
943 // space filling curve for shuffled blockwise C in global mem
944 constexpr auto sfc_c_global =
947 Sequence<1,
948 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
949 1,
950 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
951
952 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
953
954 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
955
956 static_for<0, num_access, 1>{}([&](auto access_id) {
957 // make sure it's safe to write to LDS
959
960 // each thread write its data from VGPR to LDS
961 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
962 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
963 c_thread_buf,
964 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
965 c_shuffle_block_buf);
966
967 // make sure it's safe to read from LDS
969
970 // each block copy its data from LDS to global
971 c_shuffle_block_copy_lds_to_global.Run(
972 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
973 c_shuffle_block_buf,
974 c_grid_desc_mblock_mperblock_nblock_nperblock,
975 c_grid_buf);
976
977 if constexpr(access_id < num_access - 1)
978 {
979 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
980
981 // move on C
982 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
983 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
984 }
985 });
986 }
987 }
988};
989
990} // namespace ck
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:261
Blockwise gemm.
Definition blockwise_gemm_xdlops.hpp:690
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:319
static constexpr auto c_block_space_size
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:344
ck::GridwiseBatchedGemmGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b1_block_space_size_aligned
static constexpr auto b1_block_space_size_aligned
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:334
static constexpr auto b1_block_desc_bk0_n_bk1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:325
static constexpr auto b_block_space_offset
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:338
static constexpr auto b_block_desc_bk0_n_bk1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:323
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:342
static constexpr auto b1_block_space_offset
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:339
static constexpr auto a_block_space_offset
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:337
static constexpr auto a_block_desc_ak0_m_ak1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:321
static constexpr auto max_lds_align
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:328
static constexpr auto b_block_space_size_aligned
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:332
static constexpr auto a_block_space_size_aligned
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:330
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:80
__host__ static __device__ constexpr auto MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:130
static constexpr auto AK1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:97
__host__ static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:152
static constexpr auto I7
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:91
static constexpr auto I3
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:87
static constexpr auto B1K1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:101
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:305
static constexpr auto AK0
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:95
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:184
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, const ADataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const AccElementwiseOperation &acc_element_op, const B1ElementwiseOperation &b1_element_op, const CElementwiseOperation &c_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:349
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:169
ck::GridwiseBatchedGemmGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched >< math::max(MXdlPerWave64, 1)>::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:311
__host__ static __device__ constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:160
__host__ static __device__ constexpr auto MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:120
static constexpr auto I0
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:84
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:277
static constexpr auto B1K0
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:100
static constexpr auto I4
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:88
static constexpr auto I1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:85
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:223
__host__ static __device__ constexpr auto MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:110
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:285
static constexpr auto BK0
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:96
__host__ static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:144
static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:200
__host__ static __device__ constexpr auto MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:137
static constexpr auto I6
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:90
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:103
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:315
static constexpr auto I5
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:89
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVersion::v1, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:105
static constexpr auto I2
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:86
static constexpr auto BK1
Definition gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp:98
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:1877
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:39
Definition xdlops_gemm.hpp:1821
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340