gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp Source File

Composable Kernel: gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp Source File
gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
18
19namespace ck {
20
25template <typename FloatAB,
26 typename FloatGemmAcc,
27 typename FloatCShuffle,
28 typename FloatC,
29 typename AElementwiseOperation,
30 typename BElementwiseOperation,
31 typename AccElementwiseOperation,
32 typename B1ElementwiseOperation,
33 typename CElementwiseOperation,
34 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
35 typename AGridDesc_AK0_M_AK1,
36 typename BGridDesc_BK0_N_BK1,
37 typename B1GridDesc_BK0_N_BK1,
38 typename CGridDesc_M_N,
39 index_t NumGemmKPrefetchStage,
40 index_t BlockSize,
41 index_t MPerBlock,
42 index_t NPerBlock,
43 index_t KPerBlock,
44 index_t Gemm1NPerBlock,
45 index_t Gemm1KPerBlock,
46 index_t AK1Value,
47 index_t BK1Value,
48 index_t B1K1Value,
49 index_t MPerXdl,
50 index_t NPerXdl,
51 index_t MXdlPerWave,
52 index_t NXdlPerWave,
53 index_t Gemm1NXdlPerWave,
54 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
55 typename ABlockTransferThreadClusterArrangeOrder,
56 typename ABlockTransferSrcAccessOrder,
57 index_t ABlockTransferSrcVectorDim,
58 index_t ABlockTransferSrcScalarPerVector,
59 index_t ABlockTransferDstScalarPerVector_AK1,
60 bool AThreadTransferSrcResetCoordinateAfterRun, // ignored
61 index_t ABlockLdsExtraM,
62 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
63 typename BBlockTransferThreadClusterArrangeOrder,
64 typename BBlockTransferSrcAccessOrder,
65 index_t BBlockTransferSrcVectorDim,
66 index_t BBlockTransferSrcScalarPerVector,
67 index_t BBlockTransferDstScalarPerVector_BK1,
68 bool BThreadTransferSrcResetCoordinateAfterRun, // ignored
69 index_t BBlockLdsExtraN,
70 typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
71 typename B1BlockTransferThreadClusterArrangeOrder,
72 typename B1BlockTransferSrcAccessOrder,
73 index_t B1BlockTransferSrcVectorDim,
74 index_t B1BlockTransferSrcScalarPerVector,
75 index_t B1BlockTransferDstScalarPerVector_BK1,
76 bool B1ThreadTransferSrcResetCoordinateAfterRun,
77 index_t B1BlockLdsExtraN,
78 index_t CShuffleMXdlPerWavePerShuffle,
79 index_t CShuffleNXdlPerWavePerShuffle,
80 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
81 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
82 LoopScheduler LoopSched,
83 bool PadN,
87{
88 static_assert(LoopSched == LoopScheduler::Default,
89 "Non-default loop scheduler is currently not supported");
90
91 static constexpr auto I0 = Number<0>{};
92 static constexpr auto I1 = Number<1>{};
93 static constexpr auto I2 = Number<2>{};
94 static constexpr auto I3 = Number<3>{};
95 static constexpr auto I4 = Number<4>{};
96 static constexpr auto I5 = Number<5>{};
97 static constexpr auto I6 = Number<6>{};
98 static constexpr auto I7 = Number<7>{};
99
100 // K1 should be Number<...>
101 // Gemm0
102 static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
103 static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
104 static constexpr auto AK1 = Number<AK1Value>{};
105 static constexpr auto BK1 = Number<BK1Value>{};
106
107 static constexpr auto Gemm0MWaves = MPerBlock / (MPerXdl * MXdlPerWave);
108 static constexpr auto Gemm0NWaves = NPerBlock / (NPerXdl * NXdlPerWave);
109
110 // Gemm1
111 static constexpr auto B1K0 = Number<Gemm1KPerBlock / B1K1Value>{};
112 static constexpr auto B1K1 = Number<B1K1Value>{};
113
115
118
119 template <typename ABlockDesc_AK0_M_AK1>
120 __host__ __device__ static constexpr auto
121 MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
122 {
123 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
124
125 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, MWaves, MPerXdl>(
126 ABlockDesc_AK0_M_AK1{});
127 }
128
129 template <typename BBlockDesc_BK0_N_BK1>
130 __host__ __device__ static constexpr auto
131 MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
132 {
133 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
134
135 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<NXdlPerWave, NWaves, NPerXdl>(
136 BBlockDesc_BK0_N_BK1{});
137 }
138
139 template <typename ABlockDesc_AK0_M_AK1>
140 __host__ __device__ static constexpr auto
141 MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
142 {
143 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<MXdlPerWave, 1, 1>(ABlockDesc_AK0_M_AK1{});
144 }
145
146 template <typename BBlockDesc_BK0_N_BK1>
147 __host__ __device__ static constexpr auto
148 MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
149 {
150 constexpr index_t Gemm1NWaves = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
151 return MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K<Gemm1NXdlPerWave, Gemm1NWaves, NPerXdl>(
152 BBlockDesc_BK0_N_BK1{});
153 }
154
155 __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
156 {
157 // A matrix in LDS memory, dst of blockwise copy
161 }
162
163 __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
164 {
165 // B matrix in LDS memory, dst of blockwise copy
169 }
170
171 __host__ __device__ static constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
172 {
173 // B1 matrix in LDS memory, dst of blockwise copy
177 }
178
179 __host__ __device__ static constexpr auto
181 {
182 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
183 constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
184
185 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
189 I1,
191
192 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
193 }
194
195 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
196 {
197 const index_t gemm0_bytes_end = (SharedMemTrait::a_block_space_size_aligned +
199 sizeof(FloatAB);
200 const index_t gemm1_bytes_end =
202 sizeof(FloatAB);
203 const index_t softmax_bytes_end = (SharedMemTrait::reduction_space_offset +
205 sizeof(FloatGemmAcc);
206 const index_t c_block_bytes_end =
207 SharedMemTrait::c_block_space_size * sizeof(FloatCShuffle);
208
209 return math::max(gemm0_bytes_end, gemm1_bytes_end, softmax_bytes_end, c_block_bytes_end);
210 }
211
212 template <
213 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
214 __device__ static bool constexpr IsValidCompilationParameter()
215 {
216 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
217 BlockSize,
218 MPerBlock,
219 NPerBlock,
220 MPerXdl,
221 NPerXdl,
222 MXdlPerWave,
223 NXdlPerWave,
224 FloatC,
225 CGlobalMemoryDataOperation>();
226 }
227
228 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
229 template <typename Block2CTileMap>
230 __host__ __device__ static constexpr bool
231 CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
232 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
233 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
234 const CGridDesc_M_N& c_grid_desc_m_n,
235 const Block2CTileMap& block_2_ctile_map)
236 {
237 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
238 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
239 "Invalid tuning param!");
240
241 const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
242 const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
243 const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
244 const auto Gemm1N = b1_grid_desc_bk0_n_bk1.GetLength(I1);
245
246 if(!(M == c_grid_desc_m_n.GetLength(I0) && Gemm1N == c_grid_desc_m_n.GetLength(I1)))
247 {
248 return false;
249 }
250
251 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
252 Gemm1N % Gemm1NPerBlock == 0))
253 {
254 return false;
255 }
256
257 // check gemm0 gridwise gemm pipeline
258 const auto num_gemm0_k_loop = K / KPerBlock;
259 if(!GridwiseGemmPipe::IsSupported(num_gemm0_k_loop))
260 {
261 return false;
262 }
263
264 // check gemm1 gridwise gemm pipeline
265 if(!(NPerBlock % Gemm1KPerBlock == 0))
266 {
267 return false;
268 }
269
270 const auto num_gemm1_k_inner_loop = NPerBlock / Gemm1KPerBlock;
271 if(!GridwiseGemmPipe::IsSupported(num_gemm1_k_inner_loop))
272 {
273 return false;
274 }
275
276 if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
277 {
278 return false;
279 }
280
281 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
282 return true;
283 }
284
285 __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
286 {
287 const index_t num_loop = K / KPerBlock;
288
289 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
290 }
291
292 __host__ __device__ static constexpr auto
293 MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
294 {
295 const auto M = c_grid_desc_m_n.GetLength(I0);
296 const auto N = c_grid_desc_m_n.GetLength(I1);
297
298 const auto MBlock = M / MPerBlock;
299 const auto NBlock = N / Gemm1NPerBlock;
300
301 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
302 c_grid_desc_m_n,
307
308 return c_grid_desc_mblock_mperblock_nblock_nperblock;
309 }
310
311 // return block_id to C matrix tile idx (m0, n0) mapping
312 __host__ __device__ static constexpr auto
313 MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
314 {
316 c_grid_desc_m_n);
317 }
318
321 CGridDesc_M_N{}))>;
322
324 remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
325
327 {
328 // LDS allocation for A and B: be careful of alignment
329 static constexpr auto a_block_desc_ak0_m_ak1 =
331 static constexpr auto b_block_desc_bk0_n_bk1 =
333 static constexpr auto b1_block_desc_bk0_n_bk1 =
335
336 static constexpr auto max_lds_align = math::lcm(math::lcm(AK1, BK1), B1K1);
337
339 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
341 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
343 b1_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
344
345 static constexpr auto a_block_space_offset = 0;
347 static constexpr auto b1_block_space_offset = 0;
348
349 // LDS allocation for reduction
352
353 static constexpr auto reduction_space_offset = 0;
354
355 // LDS allocation for C shuffle in LDS
358 static constexpr auto c_block_space_size =
360 };
361
362 template <bool HasMainKBlockLoop, typename Block2CTileMap, typename C0MatrixMask>
363 __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
364 const FloatAB* __restrict__ p_b_grid,
365 const FloatAB* __restrict__ p_b1_grid,
366 FloatC* __restrict__ p_c_grid,
367 void* __restrict__ p_shared,
368 const AElementwiseOperation& a_element_op,
369 const BElementwiseOperation& b_element_op,
370 const AccElementwiseOperation& acc_element_op,
371 const B1ElementwiseOperation& b1_element_op,
372 const CElementwiseOperation& c_element_op,
373 const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
374 const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
375 const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
377 c_grid_desc_mblock_mperblock_nblock_nperblock,
378 const Block2CTileMap& block_2_ctile_map,
379 const C0MatrixMask& c0_matrix_mask)
380 {
381 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
382 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
383 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
384 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
385 const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
386 p_b1_grid, b1_grid_desc_bk0_n_bk1.GetElementSpaceSize());
388 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
389
390 // divide block work by [M, N]
391 const auto block_work_idx =
392 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
393
394 if(!block_2_ctile_map.ValidCTileIndex(
395 block_work_idx,
396 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
397 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
398 {
399 return;
400 }
401
402 // HACK: this force m/gemm1_n_block_data_idx_on_grid into SGPR
403 const index_t m_block_data_idx_on_grid =
404 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
405
406 const index_t gemm1_n_block_data_idx_on_grid =
407 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * Gemm1NPerBlock);
408
409 // A matrix in LDS memory, dst of blockwise copy
410 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
411
412 // B matrix in LDS memory, dst of blockwise copy
413 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
414
415 //
416 // set up Gemm0
417 //
418
419 // A matrix blockwise copy
420 auto a_blockwise_copy =
422 AElementwiseOperation,
426 ABlockTransferThreadClusterLengths_AK0_M_AK1,
427 ABlockTransferThreadClusterArrangeOrder,
428 FloatAB,
429 FloatAB,
430 decltype(a_grid_desc_ak0_m_ak1),
431 decltype(a_block_desc_ak0_m_ak1),
432 ABlockTransferSrcAccessOrder,
434 ABlockTransferSrcVectorDim,
435 2,
436 ABlockTransferSrcScalarPerVector,
437 ABlockTransferDstScalarPerVector_AK1,
438 1,
439 1,
440 true, // SrcResetCoord
441 true, // DstResetCoord
442 NumGemmKPrefetchStage>(
443 a_grid_desc_ak0_m_ak1,
444 make_multi_index(0, m_block_data_idx_on_grid, 0),
445 a_element_op,
446 a_block_desc_ak0_m_ak1,
447 make_multi_index(0, 0, 0),
449
450 // B matrix blockwise copy
451 auto b_blockwise_copy =
453 BElementwiseOperation,
457 BBlockTransferThreadClusterLengths_BK0_N_BK1,
458 BBlockTransferThreadClusterArrangeOrder,
459 FloatAB,
460 FloatAB,
461 decltype(b_grid_desc_bk0_n_bk1),
462 decltype(b_block_desc_bk0_n_bk1),
463 BBlockTransferSrcAccessOrder,
465 BBlockTransferSrcVectorDim,
466 2,
467 BBlockTransferSrcScalarPerVector,
468 BBlockTransferDstScalarPerVector_BK1,
469 1,
470 1,
471 true, // SrcResetCoord
472 true, // DstResetCoord
473 NumGemmKPrefetchStage>(
474 b_grid_desc_bk0_n_bk1,
475 make_multi_index(0, 0, 0), // will loop over GemmN dimension
476 b_element_op,
477 b_block_desc_bk0_n_bk1,
478 make_multi_index(0, 0, 0),
480
481 // Fused Gemm+Gemm pipeline
482 // for n in N0:
483 // for k in K0:
484 // acc[m][n] += A[m][k] * B0[k][n]
485 // acc1[m][o] += acc[m][n] * B1[n][o]
486
487 // sanity check
488 constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
489 constexpr bool is_single_rate_mfma =
491 lcm_AK1_BK1 <= 4) ||
492 (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
494 lcm_AK1_BK1 < 32))
495 ? true
496 : false;
497 constexpr auto is_scale_mfma = false;
498 constexpr index_t KPack = math::max(
499 lcm_AK1_BK1,
501 selected_mfma.k_per_blk);
502
503 auto blockwise_gemm = BlockwiseGemmXdlops_v2<
504 BlockSize,
505 FloatAB,
506 FloatGemmAcc,
507 decltype(a_block_desc_ak0_m_ak1),
508 decltype(b_block_desc_bk0_n_bk1),
509 decltype(MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
510 decltype(MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
511 MPerBlock,
512 NPerBlock,
513 KPerBlock,
514 MPerXdl,
515 NPerXdl,
516 MXdlPerWave,
517 NXdlPerWave,
518 KPack,
519 true>{}; // TransposeC
520
521 auto acc_thread_buf = blockwise_gemm.GetCThreadBuffer();
522
523 // LDS allocation for A and B: be careful of alignment
525 static_cast<FloatAB*>(p_shared) + SharedMemTrait::a_block_space_offset,
526 a_block_desc_ak0_m_ak1.GetElementSpaceSize());
527
529 static_cast<FloatAB*>(p_shared) + SharedMemTrait::b_block_space_offset,
530 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
531
532 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
533 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
534 const auto a_block_reset_copy_step =
535 make_multi_index(-a_grid_desc_ak0_m_ak1.GetLength(I0), 0, 0);
536 const auto b_block_reset_copy_step =
537 make_multi_index(-b_grid_desc_bk0_n_bk1.GetLength(I0), NPerBlock, 0);
538
539 // gridwise GEMM pipeline
540 // Only supports LoopScheduler::Default
541 const auto gridwise_gemm_pipeline = GridwiseGemmPipeline_Selector<PipelineVer,
542 NumGemmKPrefetchStage,
544
545 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
546 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
547 KPerBlock);
548
549 //
550 // set up Gemm1
551 //
552
553 // Acc matrix threadwise copy: AccVGPR to VGPR and downcast to XDL input data type
554 constexpr auto acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
555 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
556
557 constexpr auto m0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
558 constexpr auto n0 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
559 constexpr auto m1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
560 constexpr auto n1 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
561 constexpr auto m2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
562 constexpr auto n2 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
563 constexpr auto n3 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
564 constexpr auto n4 = acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
565
566 constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1KPerBlock / B1K1, 0, 0);
567
568 // acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 to acc_thread_desc_k0_m_k1
569 // n0_n1_n2_n3 -> k0
570 // m0_m1_m2 -> m
571 // n4 -> k1
572 // NOTE: had to use merge_v3 or will spit out compilation errors
573 constexpr auto acc_thread_desc_k0_m_k1 = transform_tensor_descriptor(
574 acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
580
581 // A1 matrix in AccVGPR
582 // N2 num_groups_per_blk, N3 num_input_blks, N4 group_size
583 constexpr auto AccN3 =
584 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLength(I6);
585
586 constexpr auto A1ThreadSlice_K0_M_K1 =
588
589 constexpr auto A1ThreadSliceK0 = A1ThreadSlice_K0_M_K1[I0];
590 constexpr auto A1ThreadSliceM = A1ThreadSlice_K0_M_K1[I1];
591 constexpr auto A1ThreadSliceK1 = A1ThreadSlice_K0_M_K1[I2];
592
593 // A1 matrix blockwise copy
594#if defined(__gfx11__)
595 constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor_packed(
596 make_tuple(A1ThreadSliceK0, A1ThreadSliceM, Number<A1ThreadSliceK1 * 2>{}));
598 FloatGemmAcc,
599 FloatAB,
600 decltype(acc_thread_desc_k0_m_k1),
601 decltype(a1_thread_desc_k0_m_k1),
605 2,
606 n4,
607 0x76543210,
608 0xfedcba98,
609 false>{make_tuple(0, 0, 0)};
610 static_assert(n4 == A1ThreadSliceK1);
611#else
612 constexpr auto a1_thread_desc_k0_m_k1 = make_naive_tensor_descriptor(
613 A1ThreadSlice_K0_M_K1,
614 make_tuple(A1ThreadSliceM * A1ThreadSliceK1, A1ThreadSliceK1, I1));
615 auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_StaticToStatic<
616 FloatGemmAcc,
617 FloatAB,
618 decltype(acc_thread_desc_k0_m_k1),
619 decltype(a1_thread_desc_k0_m_k1),
623 2,
625#endif
626
627 // B1 matrix in LDS memory, dst of blockwise copy
628 constexpr auto b1_block_desc_bk0_n_bk1 = GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1();
629
630 // B1 matrix blockwise copy
631 auto b1_blockwise_copy =
633 BElementwiseOperation,
637 B1BlockTransferThreadClusterLengths_BK0_N_BK1,
638 B1BlockTransferThreadClusterArrangeOrder,
639 FloatAB,
640 FloatAB,
641 decltype(b1_grid_desc_bk0_n_bk1),
642 decltype(b1_block_desc_bk0_n_bk1),
643 B1BlockTransferSrcAccessOrder,
645 B1BlockTransferSrcVectorDim,
646 2,
647 B1BlockTransferSrcScalarPerVector,
648 B1BlockTransferDstScalarPerVector_BK1,
649 1,
650 1,
651 B1ThreadTransferSrcResetCoordinateAfterRun,
652 true, // DstResetCoord
653 NumGemmKPrefetchStage>(
654 b1_grid_desc_bk0_n_bk1,
655 make_multi_index(0, gemm1_n_block_data_idx_on_grid, 0),
656 b1_element_op,
657 b1_block_desc_bk0_n_bk1,
658 make_multi_index(0, 0, 0),
660
662 a1_thread_desc_k0_m_k1.GetElementSpaceSize());
663
664 // reuse LDS space for gemm0's b_block_buf
666 static_cast<FloatAB*>(p_shared) + SharedMemTrait::b1_block_space_offset,
667 b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
668
669 // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
670 // selected_mfma.k_per_blk <= Gemm1KPack
671 //
672 // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
673 // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
674 // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
675 // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
676 // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
677 // therefore we may just as well assign Gemm1KPack = group_size
678#if defined(__gfx11__)
679 constexpr index_t Gemm1KPack =
681#else
682 constexpr index_t Gemm1KPack =
684#endif
685
686 auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
687 BlockSize,
688 FloatAB,
689 FloatGemmAcc,
690 decltype(a1_thread_desc_k0_m_k1),
691 decltype(b1_block_desc_bk0_n_bk1),
692 decltype(MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(a1_thread_desc_k0_m_k1)),
693 decltype(MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(b1_block_desc_bk0_n_bk1)),
694 MPerBlock,
695 Gemm1NPerBlock,
696 Gemm1KPerBlock,
697 MPerXdl,
698 NPerXdl,
699 MXdlPerWave,
700 Gemm1NXdlPerWave,
701 Gemm1KPack,
702 true, // TransposeC
703 Gemm1KPack, // AMmaKStride
704 Gemm1KPack *
706 // BMmaKStride
707 make_tuple(0, 0, 0, 0)}; // A_origin
708
709 auto acc1_thread_buf = gemm1_blockwise_gemm.GetCThreadBuffer();
710
711 //
712 // Blockwise softmax
713 //
715 static_cast<FloatGemmAcc*>(p_shared) + SharedMemTrait::reduction_space_offset,
717
718 // get acc0 8D thread cluster
719 constexpr auto thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4 =
720 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths() /
721 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
722 constexpr auto tm0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I0);
723 constexpr auto tn0 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I1);
724 constexpr auto tm1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I2);
725 constexpr auto tn1 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I3);
726 constexpr auto tm2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I4);
727 constexpr auto tn2 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I5);
728 constexpr auto tn3 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I6);
729 constexpr auto tn4 = thread_cluster_m0_n0_m1_n1_m2_n2_n3_n4.At(I7);
730
731 // get acc0 thread map
732 constexpr auto m0_n_m1_to_m_n_adaptor = make_single_stage_tensor_adaptor(
737 constexpr auto threadid_to_m0_n_m1_adaptor = make_single_stage_tensor_adaptor(
739 make_merge_transform(make_tuple(tm0 * tm1, tn0 * tn1 * tn2 * tn3 * tn4, tm2))),
742 const auto threadid_to_m_n_thread_cluster_adaptor =
743 chain_tensor_adaptors(m0_n_m1_to_m_n_adaptor, threadid_to_m0_n_m1_adaptor);
744
745 // get acc0 2D thread cluster & 2D thread slice
746 constexpr auto thread_cluster_desc_m_n = make_naive_tensor_descriptor_packed(
747 make_tuple(tm0 * tm1 * tm2, tn0 * tn1 * tn2 * tn3 * tn4));
748 constexpr auto thread_slice_desc_m_n =
749 make_naive_tensor_descriptor_packed(make_tuple(m0 * m1 * m2, n0 * n1 * n2 * n3 * n4));
750
751 auto blockwise_softmax = BlockwiseSoftmax<BlockSize,
752 FloatGemmAcc,
753 decltype(threadid_to_m_n_thread_cluster_adaptor),
754 decltype(thread_cluster_desc_m_n),
755 decltype(thread_slice_desc_m_n)>{};
756
757 const index_t num_gemm1_k_block_outer_loop =
758 b_grid_desc_bk0_n_bk1.GetLength(I1) / NPerBlock * Gemm0NWaves;
759 constexpr index_t num_gemm1_k_block_inner_loop = NPerBlock / Gemm1KPerBlock / Gemm0NWaves;
760
761 // Initialize C
762 StaticBuffer<AddressSpaceEnum::Vgpr, FloatGemmAcc, acc1_thread_buf.Size(), true>
763 c_thread_buf;
764 c_thread_buf.Clear();
765
766 // Initialize running sum and max of exponentiating row vectors
767 using SoftmaxBuf = typename decltype(blockwise_softmax)::BufferType;
768 SoftmaxBuf running_sum, running_sum_new, running_max, running_max_new;
769 running_sum = 0;
770 running_sum_new = 0;
772 running_max_new = NumericLimits<FloatGemmAcc>::Lowest();
773
774 // gemm1 K loop
775 index_t gemm1_k_block_outer_index = 0;
776 do
777 {
778 auto n_block_data_idx_on_grid =
779 __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
780 if(c0_matrix_mask.IsTileSkippable(
781 m_block_data_idx_on_grid, n_block_data_idx_on_grid, MPerBlock, NPerBlock))
782 {
783 continue;
784 }
785 // gemm0
786 gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
787 a_block_desc_ak0_m_ak1,
788 a_blockwise_copy,
789 a_grid_buf,
790 a_block_buf,
791 a_block_slice_copy_step,
792 b_grid_desc_bk0_n_bk1,
793 b_block_desc_bk0_n_bk1,
794 b_blockwise_copy,
795 b_grid_buf,
796 b_block_buf,
797 b_block_slice_copy_step,
798 blockwise_gemm,
799 acc_thread_buf,
800 num_k_block_main_loop);
801
802 // do MNK padding or upper triangular masking
803 if constexpr(MaskOutUpperTriangle || PadN)
804 {
805 // 8d thread_desc in thread scope
806 constexpr auto c_thread_lengths =
807 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
808
809 // 8d block_desc in block scope
810 constexpr auto c_block_lengths =
811 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4().GetLengths();
812
813 constexpr auto M0 = c_block_lengths[I0];
814 constexpr auto N0 = c_block_lengths[I1];
815 constexpr auto M1 = c_block_lengths[I2];
816 constexpr auto N1 = c_block_lengths[I3];
817 constexpr auto M2 = c_block_lengths[I4];
818 constexpr auto N2 = c_block_lengths[I5];
819 constexpr auto N3 = c_block_lengths[I6];
820 constexpr auto N4 = c_block_lengths[I7];
821
822 // works like multi-dimension static_for (static_ford), but provides both the linear
823 // index as well as n-d index
824 using Acc0TileIterator = SpaceFillingCurve<
825 decltype(c_thread_lengths),
826 typename arithmetic_sequence_gen<0, c_thread_lengths.Size(), 1>::type,
827 typename uniform_sequence_gen<c_thread_lengths.Size(), 1>::type,
828 false>; // SnakeCurved
829
830 auto acc0_thread_origin = blockwise_gemm.CalculateCThreadOriginDataIndex8D(
832
833 constexpr auto block_idx_to_m_n_adaptor = make_single_stage_tensor_adaptor(
835 make_unmerge_transform(make_tuple(N0, N1, N2, N3, N4))),
838
839 static_for<0, Acc0TileIterator::GetNumOfAccess(), 1>{}([&](auto i) {
840 auto acc0_thread_idx = Acc0TileIterator::GetIndex(i) + acc0_thread_origin;
841 auto m_local =
842 block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I0];
843 auto n_local =
844 block_idx_to_m_n_adaptor.CalculateBottomIndex(acc0_thread_idx)[I1];
845 auto m_global = m_local + m_block_data_idx_on_grid;
846 auto n_global = n_local + n_block_data_idx_on_grid;
847 if(c0_matrix_mask.IsMaskedElement(m_global, n_global))
848 {
849 acc_thread_buf(i) = -ck::NumericLimits<float>::Infinity();
850 }
851 else
852 {
853 acc_element_op(acc_thread_buf(i), acc_thread_buf[i]);
854 }
855 });
856 }
857 else
858 {
859 static_for<0, acc_thread_buf.Size(), 1>{}(
860 [&](auto i) { acc_element_op(acc_thread_buf(i), acc_thread_buf[i]); });
861 }
862
863 block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
864
865 // softmax
866 SoftmaxBuf& max = blockwise_softmax.max_value_buf;
867 SoftmaxBuf& sum = blockwise_softmax.sum_value_buf;
868
869 blockwise_softmax.Run(acc_thread_buf, workspace_buf);
870
871 // TODO: may convert to log domain
872 running_max_new = mathext::max(max, running_max);
873 running_sum_new = mathext::exp(running_max - running_max_new) * running_sum +
874 mathext::exp(max - running_max_new) * sum;
875
876 // gemm1
877 {
878 // TODO: explore using dynamic buffer for a1 thread buffer
879 // For a1_blockwise_copy, the goal is to satisfy pipeline requirements RunRead(),
880 // RunWrite(), and MoveSliceWindow(). But it is impossible to implement given that
881 // the A1 source buffer is static buffer holding the output of first GEMM and
882 // requires constexpr offset by design. Therefore, we pass tensor coordinate offset
883 // explicitly in Run() below.
884
885 // Initialize acc1
886 acc1_thread_buf.Clear();
887
888 // preload data into LDS
889 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
890
891 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
892 b1_block_slice_copy_step);
893
894 block_sync_lds(); // wait for reduction LDS read
895
896 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
897
898 // main body
899 if constexpr(num_gemm1_k_block_inner_loop > 1)
900 {
901 static_for<0, num_gemm1_k_block_inner_loop - 1, 1>{}([&](auto i) {
902 a1_blockwise_copy.Run(acc_thread_desc_k0_m_k1,
904 acc_thread_buf,
905 a1_thread_desc_k0_m_k1,
906 make_tuple(I0, I0, I0),
907 a1_thread_buf);
908
909 b1_blockwise_copy.RunRead(b1_grid_desc_bk0_n_bk1, b1_grid_buf);
910
912
913 gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
914
916
917 b1_blockwise_copy.MoveSrcSliceWindow(b1_grid_desc_bk0_n_bk1,
918 b1_block_slice_copy_step);
919
920 b1_blockwise_copy.RunWrite(b1_block_desc_bk0_n_bk1, b1_block_buf);
921 });
922 }
923 // tail
924 {
925 a1_blockwise_copy.Run(
926 acc_thread_desc_k0_m_k1,
928 Number<(num_gemm1_k_block_inner_loop - 1) * A1ThreadSliceK0>{}, I0, I0),
929 acc_thread_buf,
930 a1_thread_desc_k0_m_k1,
931 make_tuple(I0, I0, I0),
932 a1_thread_buf);
933
935
936 gemm1_blockwise_gemm.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
937 }
938 } // end gemm1
939
940 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
941 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
942 constexpr auto cm0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I0);
943 constexpr auto cn0 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I1);
944 constexpr auto cm1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I2);
945 constexpr auto cn1 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I3);
946 constexpr auto cm2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I4);
947 constexpr auto cn2 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I5);
948 constexpr auto cn3 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I6);
949 constexpr auto cn4 = c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetLength(I7);
950 constexpr auto c_thread_slice_desc_m_n = make_naive_tensor_descriptor_packed(
951 make_tuple(cm0 * cm1 * cm2, cn0 * cn1 * cn2 * cn3 * cn4));
952 constexpr auto c_thread_buf_slice_m = c_thread_slice_desc_m_n.GetLength(I0);
953 constexpr auto c_thread_buf_slice_n = c_thread_slice_desc_m_n.GetLength(I1);
954
957 auto I = Number<c_thread_slice_desc_m_n.CalculateOffset(make_tuple(iM, iN))>{};
958 FloatGemmAcc acc1 = acc1_thread_buf[I]; // P*V
959 FloatGemmAcc c = c_thread_buf[I]; // O
960 FloatGemmAcc c_new =
961 (running_sum[iM] * math::exp(running_max[iM] - running_max_new[iM]) * c +
962 math::exp(max[iM] - running_max_new[iM]) * acc1) /
963 running_sum_new[iM]; // Formula by Dao et al.,
964 // https://arxiv.org/pdf/2205.14135v2.pdf section 3.1
965
966 c_thread_buf(I) = c_new; // O_new
967 });
968 });
969
970 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1,
971 a_block_reset_copy_step); // rewind K
972 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1,
973 b_block_reset_copy_step); // rewind K and step N
974
975 // update before next j iteration
976 running_max = running_max_new;
977 running_sum = running_sum_new;
978
979 block_sync_lds(); // wait for gemm1 LDS read
980 } while(++gemm1_k_block_outer_index < num_gemm1_k_block_outer_loop); // end j loop
981
982 // shuffle C and write out
983 {
984 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
985 Gemm1NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
986 "wrong!");
987
988 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
989 constexpr index_t NWave = Gemm1NPerBlock / (Gemm1NXdlPerWave * NPerXdl);
990
991 // TODO: hacky, fix it!
992 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
993 gemm1_blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
994
995 // TODO: hacky, fix it!
996 // c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp is only used to get lengths
997 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
998 gemm1_blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
999
1000 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1001 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1002 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1003 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1004 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1005 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1006 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1007 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1008
1009 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1011
1012 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1013 static_cast<FloatCShuffle*>(p_shared),
1014 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1015
1016 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
1017 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1018 make_tuple(
1021 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1022 M1, // M1 = MWave
1023 M2)), // M2 = MPerXdl
1026 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1027 N1, // N1 = NWave
1028 N2, // N2 * N3 * N4 = NPerXdl
1029 N3,
1030 N4))),
1032 make_tuple(
1034
1035 // calculate origin of thread output tensor on global memory
1036 // blockwise GEMM c matrix starting index
1037 const auto c_thread_mtx_on_block =
1038 gemm1_blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1039
1040 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1041 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1042
1043 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1048
1049 const auto m_thread_data_on_block_idx =
1050 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1051 make_multi_index(m_thread_data_on_block));
1052
1053 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1055 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3, N4))),
1058
1059 const auto n_thread_data_on_block_idx =
1060 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1061 make_multi_index(n_thread_data_on_block));
1062
1063 // shuffle: threadwise copy C from VGPR to LDS
1064 auto c_thread_copy_vgpr_to_lds =
1066 FloatCShuffle,
1067 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1068 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1070 Sequence<CShuffleMXdlPerWavePerShuffle,
1071 CShuffleNXdlPerWavePerShuffle,
1072 I1,
1073 I1,
1074 I1,
1075 N2,
1076 I1,
1077 N4>,
1079 7,
1080 1,
1082 1,
1083 true>{
1084 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1086 0,
1087 m_thread_data_on_block_idx[I1],
1088 n_thread_data_on_block_idx[I1],
1089 m_thread_data_on_block_idx[I2],
1090 n_thread_data_on_block_idx[I2],
1091 n_thread_data_on_block_idx[I3],
1092 n_thread_data_on_block_idx[I4]),
1094
1095 // shuffle: blockwise copy C from LDS to global
1096 auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1097 ThisThreadBlock, // ThreadGroup
1098 CElementwiseOperation, // ElementwiseOperation,
1099 CGlobalMemoryDataOperation, // DstInMemOp,
1100 Sequence<1,
1101 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1102 1,
1103 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1104 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1105 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1106 FloatCShuffle, // typename SrcData,
1107 FloatC, // typename DstData,
1108 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1109 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1110 Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1111 3, // index_t VectorDim,
1112 CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1113 true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1114 false> // bool ThreadTransferDstResetCoordinateAfterRun>
1115 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1116 make_multi_index(0, 0, 0, 0),
1117 c_grid_desc_mblock_mperblock_nblock_nperblock,
1118 make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1119 c_element_op};
1120
1121 // space filling curve for threadwise C in VGPR
1122 constexpr auto sfc_c_vgpr =
1125 Sequence<CShuffleMXdlPerWavePerShuffle,
1126 CShuffleNXdlPerWavePerShuffle,
1127 1,
1128 1,
1129 1,
1130 N2,
1131 1,
1132 N4>>{};
1133
1134 // space filling curve for shuffled blockwise C in global mem
1135 constexpr auto sfc_c_global =
1138 Sequence<1,
1139 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1140 1,
1141 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1142
1143 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1144
1145 static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1146
1147 static_for<0, num_access, 1>{}([&](auto access_id) {
1148 // make sure it's safe to write to LDS
1150
1151 // each thread write its data from VGPR to LDS
1152 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1153 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1154 c_thread_buf,
1155 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1156 c_shuffle_block_buf);
1157
1158 // make sure it's safe to read from LDS
1160
1161 // each block copy its data from LDS to global
1162 c_shuffle_block_copy_lds_to_global.Run(
1163 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1164 c_shuffle_block_buf,
1165 c_grid_desc_mblock_mperblock_nblock_nperblock,
1166 c_grid_buf);
1167
1168 if constexpr(access_id < num_access - 1)
1169 {
1170 constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1171
1172 // move on C
1173 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1174 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1175 }
1176 });
1177 }
1178 }
1179};
1180
1181} // namespace ck
__host__ T exp(T x)
Definition math_v2.hpp:391
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
__host__ __device__ constexpr auto exp(const Tuple< Xs... > &x)
Definition statically_indexed_array_multi_index.hpp:124
__host__ __device__ constexpr auto max(const Tuple< Xs... > &x, const Y &y)
Definition statically_indexed_array_multi_index.hpp:134
@ MaskOutUpperTriangle
Definition masking_specialization.hpp:13
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition tensor_description/tensor_adaptor.hpp:245
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
Definition block_to_ctile_map.hpp:261
Blockwise gemm.
Definition blockwise_gemm_xdlops.hpp:690
Blockwise softmax.
Definition blockwise_softmax.hpp:32
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:327
static constexpr auto a_block_space_offset
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:345
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b_block_space_size_aligned
static constexpr auto b_block_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:340
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::reduction_space_size_aligned
static constexpr index_t reduction_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:350
static constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:356
static constexpr auto a_block_desc_ak0_m_ak1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:329
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b1_block_space_size_aligned
static constexpr auto b1_block_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:342
static constexpr auto b1_block_desc_bk0_n_bk1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:333
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::b1_block_space_offset
static constexpr auto b1_block_space_offset
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:347
static constexpr auto max_lds_align
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:336
static constexpr auto b_block_desc_bk0_n_bk1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:331
static constexpr auto b_block_space_offset
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:346
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::reduction_space_offset
static constexpr auto reduction_space_offset
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:353
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::a_block_space_size_aligned
static constexpr auto a_block_space_size_aligned
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:338
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::SharedMemTrait::c_block_space_size
static constexpr auto c_block_space_size
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:358
Gridwise gemm + softmax + gemm fusion.
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:87
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:293
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I4
static constexpr auto I4
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:95
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
__host__ static __device__ constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:163
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
__host__ static __device__ constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:180
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::MakeDefaultBlock2CTileMap
__host__ static __device__ constexpr auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:313
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:231
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::B1K0
static constexpr auto B1K0
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:111
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1
__host__ static __device__ constexpr auto GetB1BlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:171
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::Gemm0NWaves
static constexpr auto Gemm0NWaves
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:108
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I3
static constexpr auto I3
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:94
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I0
static constexpr auto I0
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:91
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I1
static constexpr auto I1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:92
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::IsValidCompilationParameter
static __device__ bool constexpr IsValidCompilationParameter()
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:214
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::BK0
static constexpr auto BK0
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:103
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::AK1
static constexpr auto AK1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:104
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::MakeGemm1BMmaTileDescriptor_N0_N1_N2_K
__host__ static __device__ constexpr auto MakeGemm1BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:148
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::MakeGemm0AMmaTileDescriptor_M0_M1_M2_K
__host__ static __device__ constexpr auto MakeGemm0AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:121
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::GridwiseGemmPipe
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVersion::v1, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:116
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::ThisThreadBlock
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:114
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::BK1
static constexpr auto BK1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:105
static __device__ void Run(const ADataType *__restrict__ p_a_grid, const ADataType *__restrict__ p_b_grid, const ADataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const AccElementwiseOperation &acc_element_op, const B1ElementwiseOperation &b1_element_op, const CElementwiseOperation &c_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 &b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap &block_2_ctile_map, const C0MatrixMask &c0_matrix_mask)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:363
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I5
static constexpr auto I5
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:96
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::B1K1
static constexpr auto B1K1
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:112
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::MakeGemm1AMmaTileDescriptor_M0_M1_M2_K
__host__ static __device__ constexpr auto MakeGemm1AMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:141
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::Gemm0MWaves
static constexpr auto Gemm0MWaves
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:107
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I6
static constexpr auto I6
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:97
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::DefaultBlock2CTileMap
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:323
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I7
static constexpr auto I7
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:98
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::GetSharedMemoryNumberOfByte
__host__ static __device__ constexpr index_t GetSharedMemoryNumberOfByte()
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:195
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::I2
static constexpr auto I2
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:93
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::MakeGemm0BMmaTileDescriptor_N0_N1_N2_K
__host__ static __device__ constexpr auto MakeGemm0BMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:131
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::AK0
static constexpr auto AK0
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:102
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >< math::max(MXdlPerWave64, 1)>::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:319
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
__host__ static __device__ constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:155
ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle< ADataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, AccElementwiseOperation, B1ElementwiseOperation, CElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, B1GridDesc_BK0_N_BK1, CGridDesc_M_N, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock, Gemm1KPerBlock, AK1, BK1, B1K1, MPerXDL, NPerXDL, MXdlPerWave_, NXdlPerWave, Gemm1NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, true, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, true, BBlockLdsExtraN, B1BlockTransferThreadClusterLengths_BK0_N_BK1, B1BlockTransferThreadClusterArrangeOrder, B1BlockTransferSrcAccessOrder, B1BlockTransferSrcVectorDim, B1BlockTransferSrcScalarPerVector, B1BlockTransferDstScalarPerVector_BK1, false, B1BlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched, matrix_padder.PadN, MaskOutUpperTriangle >::CalculateHasMainKBlockLoop
__host__ static __device__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp:285
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
static constexpr auto selected_mfma
Definition xdlops_gemm.hpp:1757
__host__ static __device__ constexpr T Lowest()
Definition numeric_limits.hpp:312
__host__ static __device__ constexpr T Infinity()
Definition numeric_limits.hpp:317
Definition utility/sequence.hpp:43
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:16
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition threadwise_tensor_slice_transfer.hpp:1877
Threadwise data transfer.
Definition threadwise_tensor_slice_transfer.hpp:1720
Definition threadwise_tensor_slice_transfer.hpp:39
Definition xdlops_gemm.hpp:1821
static constexpr auto K0PerXdlops
Definition xdlops_gemm.hpp:2201
Definition utility/sequence.hpp:256
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition utility/sequence.hpp:289