gridwise_moe_mx_gemm_bns.hpp Source File

gridwise_moe_mx_gemm_bns.hpp Source File#

Composable Kernel: gridwise_moe_mx_gemm_bns.hpp Source File
gridwise_moe_mx_gemm_bns.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7#include "ck/utility/env.hpp"
18
20
21#define DEBUG_LOG 0
22
23namespace ck {
24
25// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
26// pipelines in the same kernel function. Blockers:
27// 1. Two separate declarations of __shared__ pointers are the key to making sure data access
28// operates on two LDS chunks.
29// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not
30// use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
31
33{
34 gelu_and_mul = 0,
35 silu_and_mul = 1
36};
37
38template <typename GridwiseGemm,
39 bool HasMainKBlockLoop,
40 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
41 index_t MinimumOccupancy = 1,
43__global__ void
44#if CK_USE_LAUNCH_BOUNDS
45__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
46#endif
47 // __attribute__((amdgpu_waves_per_eu(1, 1)))
48 kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
49{
50#if defined(__gfx9__)
51 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
52 {
53 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
54
55 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
56
57 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
58 karg.p_sorted_token_ids,
59 karg.p_sorted_expert_ids,
60 karg.p_max_token_id,
61 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
62 karg.p_a_scale_grid + splitk_batch_offset.a_k_split_offset,
63 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
64 karg.p_b_scale_grid + splitk_batch_offset.b_k_split_offset,
65 karg.p_ds_grid,
66 karg.p_c_grid,
67 p_shared,
68 karg,
69 karg.a_element_op,
70 karg.b_element_op,
71 karg.c_element_op);
72 }
73#else
74 ignore = karg;
75#endif // end of if (defined(__gfx9__))
76}
77
78#if 0
79template <typename GridwiseGemm,
80 bool HasMainKBlockLoop,
81 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
82 index_t MinimumOccupancy = 1,
84__global__ void
85#if CK_USE_LAUNCH_BOUNDS
86__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
87#endif
88 // __attribute__((amdgpu_waves_per_eu(1, 1)))
89 kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
90{
91#if defined(__gfx9__)
92 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
93 {
94 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
95 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
96
97 // auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
98
99 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
100 karg.p_sorted_token_ids,
101 karg.p_sorted_expert_ids,
102 karg.p_max_token_id,
103 karg.p_a_grid,
104 karg.p_a_scale_grid,
105 karg.p_b_grid,
106 karg.p_b_scale_grid,
107 karg.p_ds_grid,
108 karg.p_c_grid,
109 p_shared,
110 p_shared1,
111 karg,
112 karg.a_element_op,
113 karg.b_element_op,
114 karg.c_element_op);
115 }
116#else
117 ignore = karg;
118#endif // end of if (defined(__gfx9__))
119}
120#endif
121
122template <typename ALayout,
123 typename BLayout,
124 typename DsLayout,
125 typename CLayout,
126 typename ADataType,
127 typename AScaleDataType,
128 typename BDataType,
129 typename BScaleDataType,
130 typename AccDataType,
131 typename CShuffleDataType,
132 typename DsDataType,
133 typename CDataType,
134 typename AElementwiseOperation,
135 typename BElementwiseOperation,
136 typename CElementwiseOperation,
138 index_t ScaleBlockSize,
139 index_t BlockSize,
140 index_t MPerBlock,
141 index_t NPerBlock,
142 index_t KPerBlock,
143 index_t AK1Value,
144 index_t BK1Value,
145 index_t MPerXdl,
146 index_t NPerXdl,
147 index_t MXdlPerWave,
148 index_t NXdlPerWave,
149 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
150 typename ABlockTransferThreadClusterArrangeOrder,
151 typename ABlockTransferSrcAccessOrder,
152 index_t ABlockTransferSrcVectorDim,
153 index_t ABlockTransferSrcScalarPerVector,
154 index_t ABlockTransferDstScalarPerVector_AK1,
155 bool AThreadTransferSrcResetCoordinateAfterRun,
156 index_t ABlockLdsExtraM,
157 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
158 typename BBlockTransferThreadClusterArrangeOrder,
159 typename BBlockTransferSrcAccessOrder,
160 index_t BBlockTransferSrcVectorDim,
161 index_t BBlockTransferSrcScalarPerVector,
162 index_t BBlockTransferDstScalarPerVector_BK1,
163 bool BThreadTransferSrcResetCoordinateAfterRun,
164 index_t BBlockLdsExtraN,
165 index_t CShuffleMXdlPerWavePerShuffle,
166 index_t CShuffleNXdlPerWavePerShuffle,
167 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
168 typename CDEShuffleBlockTransferScalarPerVectors,
171 index_t ActivationOperation = 0,
172 bool NSwizzle = false,
173 bool IsInputGemm = true,
174 bool MulRoutedWeight = true,
175 typename IndexType = index_t,
176 typename ComputeTypeA = ADataType,
177 typename ComputeTypeB = BDataType>
179{
180 using LDSTypeA = ADataType;
181 using LDSTypeB = BDataType;
182
183 static constexpr auto I0 = Number<0>{};
184 static constexpr auto I1 = Number<1>{};
185 static constexpr auto I2 = Number<2>{};
186 static constexpr auto I3 = Number<3>{};
187 static constexpr auto I4 = Number<4>{};
188 static constexpr auto I5 = Number<5>{};
189 static constexpr auto I6 = Number<6>{};
190 static constexpr auto I7 = Number<7>{};
191 static constexpr auto I8 = Number<8>{};
192 static constexpr auto I9 = Number<9>{};
193
195 CDEShuffleBlockTransferScalarPerVectors{}[I0];
196 // K1 should be Number<...>
197 static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
198 static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
199 static constexpr auto AK1Number = Number<AK1Value>{};
200 static constexpr auto BK1Number = Number<BK1Value>{};
201
202 static constexpr index_t NumDTensor = DsDataType::Size();
203
204 static constexpr auto MXdlPack = 2;
205 static constexpr auto NXdlPack = 2;
206 static constexpr auto KXdlPack = 2;
207
210
211 static constexpr bool is_single_rate_mfma = false;
212 static constexpr auto is_scale_mfma = true;
213 using mfma_selector = MfmaSelector<ComputeTypeA,
214 MPerXdl,
215 NPerXdl,
216 ComputeTypeB,
219 static constexpr index_t KPack = math::max(
221
222 // static constexpr index_t NumTokens = 1;
223 static constexpr index_t SortedTileSize = MPerBlock;
224
225 static constexpr auto MakeDsGridPointer()
226 {
227 return generate_tuple(
228 [&](auto i) {
229 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
230
231 return static_cast<const DDataType*>(nullptr);
232 },
234 }
235
236 using DsGridPointer = decltype(MakeDsGridPointer());
237
239
240 __host__ static auto CalculateGridSize(index_t M, index_t N)
241 {
242 const index_t nblock = math::integer_divide_ceil(N, NPerBlock);
243 const index_t mblock = math::integer_divide_ceil(M, MPerBlock);
244 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
245 const index_t gridy = NSwizzle ? 1 : mblock;
246
247 return std::make_tuple(gridx, gridy, 1);
248 }
249
250 __host__ static auto CalculateMPadded(index_t M)
251 {
252 return math::integer_least_multiple(M, MPerBlock);
253 }
254
255 __host__ static auto CalculateNPadded(index_t N)
256 {
257 return math::integer_least_multiple(N, NPerBlock);
258 }
259
260 __host__ static auto CalculateKPadded(index_t K)
261 {
262 return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
263 }
264
265 __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
266 {
267 auto K_t = K_Batch * KPerBlock;
268 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
269 }
270
271 __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
272 {
273 auto K_t = K_Batch * KPerBlock;
274 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
275 }
276
277 __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
278 {
279 auto K_t = K_Batch * KPerBlock;
280 return (K + K_t - 1) / K_t * KPerBlock;
281 }
282
283 __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
284 {
285 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
286 auto K_t = K_Batch * KReadVec;
287 return (K + K_t - 1) / K_t * KReadVec;
288 }
289
290 __host__ static auto CalculateMBlock(index_t M)
291 {
292 return math::integer_divide_ceil(M, MPerBlock);
293 }
294
295 __host__ static auto CalculateNBlock(index_t N)
296 {
297 return math::integer_divide_ceil(N, NPerBlock);
298 }
299
300 template <index_t MNXdlPerWave,
301 index_t MNWaves,
302 index_t MNXdlPack,
303 index_t MNPerXdl,
304 typename TileDesc_K0_MN_K1>
305 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
306 {
307 constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
308 constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
309
311 TileDesc_K0_MN_K1{},
316 Number<MNPerXdl>{}))),
319 }
320
321 __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
322 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
323 {
324 const auto a_grid_desc_mraw_kraw = [&]() {
326 {
327 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
328 }
330 {
331 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
332 }
333 }();
334
335 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
336
337 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
338 GemmSpec == GemmSpecialization::MNKPadding)
339 {
340 // pad both M and K
341 const auto a_grid_desc_m_k =
342 transform_tensor_descriptor(a_grid_desc_mraw_kraw,
344 make_right_pad_transform(K, KPad - K)),
347
348 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
349 a_grid_desc_m_k,
354
355 return a_grid_desc_ak0_m_ak1;
356 }
357 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
358 GemmSpec == GemmSpecialization::MNPadding)
359 {
360 // pad M, but not K
361 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
362 a_grid_desc_mraw_kraw,
364 make_right_pad_transform(M, MPad - M)),
367
368 return a_grid_desc_ak0_m_ak1;
369 }
370 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
371 GemmSpec == GemmSpecialization::NKPadding)
372 {
373 // pad K, but not M
374 const auto a_grid_desc_m_k = transform_tensor_descriptor(
375 a_grid_desc_mraw_kraw,
379
380 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
381 a_grid_desc_m_k,
386
387 return a_grid_desc_ak0_m_ak1;
388 }
389 else
390 {
391 // not pad M or K
392 const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
393 a_grid_desc_mraw_kraw,
398
399 return a_grid_desc_ak0_m_ak1;
400 }
401 }
402
403 __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
404 index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
405 {
406 const auto b_grid_desc_nraw_kraw = [&]() {
408 {
409 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
410 }
412 {
413 return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
414 }
415 }();
416
417 using GemmSpecialization = tensor_operation::device::GemmSpecialization;
418
420 GemmSpec != GemmSpecialization::Default),
421 "pk_i4_t does not support padding");
423 GemmSpec != GemmSpecialization::Default),
424 "f4x2_pk_t does not support padding");
425
426 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
427 GemmSpec == GemmSpecialization::MNKPadding)
428 {
429 // pad both N and K
430 const auto b_grid_desc_n_k =
431 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
433 make_right_pad_transform(K, KPad - K)),
436
437 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
438 b_grid_desc_n_k,
443
444 return b_grid_desc_bk0_n_bk1;
445 }
446 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
447 GemmSpec == GemmSpecialization::MNPadding)
448 {
449 // pad N, but not K
450 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
451 b_grid_desc_nraw_kraw,
453 make_right_pad_transform(N, NPad - N)),
456
457 return b_grid_desc_bk0_n_bk1;
458 }
459 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
460 GemmSpec == GemmSpecialization::MKPadding)
461 {
462 // pad K, but not N
463 const auto b_grid_desc_n_k = transform_tensor_descriptor(
464 b_grid_desc_nraw_kraw,
468
469 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
470 b_grid_desc_n_k,
475
476 return b_grid_desc_bk0_n_bk1;
477 }
478 else
479 {
480 // not pad N or K
481 const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
482 b_grid_desc_nraw_kraw,
487
488 return b_grid_desc_bk0_n_bk1;
489 }
490 }
491
492 template <typename ABlockDesc_AK0_M_AK1>
493 __host__ __device__ static constexpr auto
494 MakeAMmaTileDescriptor_M0_M1_M2_M3_K(const ABlockDesc_AK0_M_AK1&)
495 {
496 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
497
499 ABlockDesc_AK0_M_AK1{});
500 }
501
502 template <typename BBlockDesc_BK0_N_BK1>
503 __host__ __device__ static constexpr auto
504 MakeBMmaTileDescriptor_N0_N1_N2_N3_K(const BBlockDesc_BK0_N_BK1&)
505 {
506 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
507
509 BBlockDesc_BK0_N_BK1{});
510 }
511
512 template <typename ELayout>
513 __host__ __device__ static auto MakeCGridDescriptor_M_N(
514 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
515 {
516 const auto c_grid_desc_mraw_nraw = [&]() {
518 {
519 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
520 }
522 {
523 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
524 }
525 }();
526
527 // pad M and N
528 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
530 make_right_pad_transform(N, NPad - N)),
533 }
534
535 template <typename DLayout>
536 __host__ __device__ static auto
538 {
539 const auto c_grid_desc_mraw_nraw = [&]() {
541 {
542 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I0));
543 }
545 {
546 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I0, StrideC));
547 }
548 }();
549
550 // pad M and N
551 return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
553 make_right_pad_transform(N, NPad - N)),
556 }
557
558 __host__ __device__ static auto MakeDsGridDescriptor_M_N(
559 index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
560 {
561 return generate_tuple(
562 [&](auto i) {
563 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
564 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
565 },
567 }
568
569 template <typename DsGridDesc>
571 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
572 {
573 return generate_tuple(
574 [&](auto i) {
576 ds_grid_desc_m_n[i], MBlock, NBlock);
577 },
579 }
580
581 struct Problem
582 {
583 __host__ Problem(index_t NumTokens_,
584 index_t TopK_,
585 index_t M_,
586 index_t N_,
587 index_t K_,
588 index_t StrideA_,
589 index_t StrideScaleA_,
590 index_t StrideB_,
591 index_t StrideScaleB_,
592 std::array<index_t, NumDTensor> StrideDs_,
593 index_t StrideC_,
594 index_t KBatch_)
595 : NumTokens{NumTokens_},
596 TopK{TopK_},
597 M{M_},
598 N{N_},
599 K{K_},
600 StrideA{StrideA_},
601 StrideScaleA{StrideScaleA_},
602 StrideB{StrideB_},
603 StrideScaleB{StrideScaleB_},
604 StrideDs{StrideDs_},
605 StrideC{StrideC_},
606 KBatch{KBatch_},
609 KRead{CalculateKRead(K_, KBatch_)},
610 KPadded{CalculateKPadded(K_, KBatch_)},
611 AK0{CalculateAK0Padded(K_, KBatch_)},
612 BK0{CalculateBK0Padded(K_, KBatch_)},
615 {
616 }
617
618 __host__ void Print() const
619 {
620 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
621 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
622 << "SA:" << StrideA << ", " << "SScaleA:" << StrideScaleA << ", "
623 << "SB:" << StrideB << ", " << "SScaleB:" << StrideScaleB << ", "
624 << "SC:" << StrideC << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded
625 << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
626 << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
627 << ", " << "NBlock: " << NBlock << "}" << std::endl;
628 }
629
639 std::array<index_t, NumDTensor> StrideDs;
650 };
651
652 // Argument
654 {
655 __host__ Argument(const index_t* p_sorted_token_ids_,
656 const index_t* p_sorted_expert_ids_,
657 const index_t* p_max_token_id_,
658 const ADataType* p_a_grid_,
659 const AScaleDataType* p_a_scale_grid_,
660 const BDataType* p_b_grid_,
661 const BScaleDataType* p_b_scale_grid_,
662 std::array<const void*, NumDTensor> p_ds_grid_,
663 CDataType* p_c_grid_,
664 index_t NumTokens_,
665 index_t TopK_,
666 index_t M_,
667 index_t N_,
668 index_t K_,
669 index_t StrideA_,
670 index_t StrideScaleA_,
671 index_t StrideB_,
672 index_t StrideScaleB_,
673 std::array<index_t, NumDTensor> StrideDs_,
674 index_t StrideC_,
675 index_t k_batch_,
676 AElementwiseOperation a_element_op_,
677 BElementwiseOperation b_element_op_,
678 CElementwiseOperation c_element_op_)
679 : Problem{NumTokens_,
680 TopK_,
681 M_,
682 N_,
683 K_ / APackedSize,
684 StrideA_ / APackedSize,
685 StrideScaleA_,
686 StrideB_ / BPackedSize,
687 StrideScaleB_,
688 StrideDs_,
689 StrideC_,
690 k_batch_},
691 p_sorted_token_ids{p_sorted_token_ids_},
692 p_sorted_expert_ids{p_sorted_expert_ids_},
693 p_max_token_id{p_max_token_id_},
694 p_a_grid{p_a_grid_},
695 p_a_scale_grid{p_a_scale_grid_},
696 p_b_grid{p_b_grid_},
697 p_b_scale_grid{p_b_scale_grid_},
698 p_ds_grid{},
699 p_c_grid{p_c_grid_},
700 a_element_op{a_element_op_},
701 b_element_op{b_element_op_},
702 c_element_op{c_element_op_}
703 {
704
705 // populate pointer, desc for Ds
706 static_for<0, NumDTensor, 1>{}([&](auto i) {
707 using DDataType_ = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
708
709 // D pointer
710 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
711 });
712 }
713
717 const ADataType* p_a_grid;
718 const AScaleDataType* p_a_scale_grid;
719 const BDataType* p_b_grid;
720 const BScaleDataType* p_b_scale_grid;
722 CDataType* p_c_grid;
723
724 const AElementwiseOperation a_element_op;
725 const BElementwiseOperation b_element_op;
726 const CElementwiseOperation c_element_op;
727 };
728
730 {
731 __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
732 {
734 {
735 a_k_split_offset = k_id * karg.KRead;
736 }
738 {
739 a_k_split_offset = k_id * karg.KRead * karg.StrideA;
740 }
741
743 {
744 b_k_split_offset = k_id * karg.KRead * karg.StrideB;
745 }
747 {
748 // KPack * NLane * KLane * K0 * N0
749 b_k_split_offset = k_id * karg.KRead;
750 }
751
752 // Calculate A scale offset
754 {
755 a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize);
756 }
758 {
760 k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA;
761 }
762
763 // Calculate B scale offset
765 {
767 k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB;
768 }
770 {
771 b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize);
772 }
773
774 if(k_id < karg.KBatch - 1)
775 {
776 karg.K = karg.KRead;
777 }
778 else
779 {
780 karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
781 }
782 }
783
788 };
789
790 __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
791 {
792 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
793 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
794 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
795
796 // A matrix in LDS memory, dst of blockwise copy
797 if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
798 {
802 }
803 // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
804 // in some cases.
806 {
807 constexpr auto a_lds_block_desc =
810
811 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
812 a_lds_block_desc,
818
819 return a_lds_block_desc_permuted;
820 }
821 else // ColumnMajor A
822 {
823 // kfold and mpair dimension is not always required.
824 // more dimension in merge_transform increase the difficulty of generating immarg offset
825 // for compiler.
826 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
827 constexpr auto M1 = MPerBlock / M0;
828
829 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
830 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
831 constexpr auto KThreadRead = WaveSize / MPerXdl;
832 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
833
834 constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
835 ? 1
836 : 128 / (AK1Number * M0 * sizeof(ADataType));
837 constexpr auto KThreadReadPerm =
838 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
839 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
840 : KThreadRead;
841
842 // 1<=mpair<=n0
843 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
844 ? 1
845 : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
846 ? M0
847 : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
848
849 constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
853 Number<kfold * M0 / mpair>{},
855 AK1Number));
856
857 constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
858 a_lds_block_desc,
863 make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
870
871 constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
872 a_lds_block_desc_permuted,
881 Sequence<1>{},
882 Sequence<2>{},
883 Sequence<3>{},
884 Sequence<4>{},
885 Sequence<5>{}),
887 Sequence<2>{},
890 Sequence<6>{},
891 Sequence<7>{}));
892
893 constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
894 a_lds_block_desc_unmerged,
897 Number<KThreadWrite / kfold / KThreadReadPerm>{},
905
906 return a_lds_block_desc_ak0_m_ak1;
907 }
908 }
909
910 __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
911 {
912 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
913 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
914 constexpr index_t WaveSize = BlockSize / (MWave * NWave);
915
916 // B matrix in LDS memory, dst of blockwise copy
917 if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
918 {
922 }
924 {
925 // NLdsLayer * K0 as logical Bank
926 constexpr auto b_lds_block_desc =
929
930 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
931 b_lds_block_desc,
937
938 return b_lds_block_desc_permuted;
939 }
940 else // RowMajor B
941 {
942 constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
943 constexpr auto N1 = NPerBlock / N0;
944
945 constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
946 constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
947 constexpr auto KThreadRead = WaveSize / NPerXdl;
948 constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
949
950 constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
951 ? 1
952 : 128 / (BK1Number * N0 * sizeof(BDataType));
953 constexpr auto KThreadReadPerm =
954 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
955 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
956 : KThreadRead;
957
958 // 1<=npair<=n0
959 constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
960 ? 1
961 : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
962 ? N0
963 : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
964
965 constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
969 Number<kfold * N0 / npair>{},
971 BK1Number));
972
973 constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
974 b_lds_block_desc,
979 make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
986
987 constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
988 b_lds_block_desc_permuted,
997 Sequence<1>{},
998 Sequence<2>{},
999 Sequence<3>{},
1000 Sequence<4>{},
1001 Sequence<5>{}),
1003 Sequence<2>{},
1006 Sequence<6>{},
1007 Sequence<7>{}));
1008
1009 constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1010 b_lds_block_desc_unmerged,
1013 Number<KThreadWrite / kfold / KThreadReadPerm>{},
1014 Number<kfold>{},
1021
1022 return b_lds_block_desc_bk0_n_bk1;
1023 }
1024 }
1025
1027 {
1028 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1029 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1030
1031 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1033 make_tuple(I1,
1035 I1,
1037
1038 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1039 }
1040
1043 BlkGemmPipelineVer,
1044 BlkGemmPipeSched,
1045 BlockSize,
1046 ScaleBlockSize,
1047 ADataType,
1048 AScaleDataType,
1049 BDataType,
1050 BScaleDataType,
1051 ComputeTypeA,
1052 AccDataType,
1059 ABlockTransferSrcScalarPerVector,
1060 BBlockTransferSrcScalarPerVector,
1061 MPerBlock,
1062 NPerBlock,
1063 KPerBlock,
1064 MPerXdl,
1065 NPerXdl,
1066 MXdlPerWave,
1067 NXdlPerWave,
1068 KPack,
1069 IsInputGemm>())>;
1070
1071 __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1072 {
1073 // LDS allocation for A and B: be careful of alignment
1074 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1075 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1076
1077 // lds max alignment
1078 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1079
1080 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1081 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1082
1083 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1084 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1085
1086 // LDS allocation for C shuffle in LDS
1087 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1089
1090 constexpr auto c_block_size =
1091 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1092
1093 if constexpr(IsInputGemm)
1094 {
1095 return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1096 b_block_space_size_aligned * sizeof(BDataType)) *
1097 2,
1098 c_block_size * sizeof(CShuffleDataType));
1099 }
1100 else
1101 {
1102 return math::max((a_block_space_size_aligned * sizeof(ADataType) +
1103 b_block_space_size_aligned * sizeof(BDataType)),
1104 c_block_size * sizeof(CShuffleDataType));
1105 }
1106 }
1107
1109
1110 // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
1111 __host__ static constexpr bool CheckValidity(const Argument& karg)
1112 {
1113 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1114 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1115 "Invalid tuning param!");
1116
1117 static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
1118 "KPerBlock should be multiple of ScaleBlockSize");
1119
1120 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1125 {
1126 if(!(karg.M % MPerBlock == 0))
1127 {
1128 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1129 {
1130 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1131 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1132 << std::endl;
1133 }
1134 return false;
1135 }
1136 }
1137
1138 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1143 {
1144 if(!(karg.N % NPerBlock == 0))
1145 {
1146 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1147 {
1148 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1149 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1150 << std::endl;
1151 }
1152 return false;
1153 }
1154 }
1155
1156 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1160 {
1161 auto K_t = karg.KBatch * KPerBlock;
1162 if(!(karg.K % K_t == 0))
1163 {
1164 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1165 {
1166 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1167 << karg.K << " " << __FILE__ << ":" << __LINE__
1168 << ", in function: " << __func__ << std::endl;
1169 }
1170 return false;
1171 }
1172 }
1173 else
1174 {
1175 constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1176 auto K_t = karg.KBatch * KReadVec;
1177 auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1178 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1179 {
1180 return false;
1181 }
1182 }
1183
1185 {
1186 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1187 {
1188 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1189 {
1190 std::cout << "Arg K (" << karg.K
1191 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1192 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1193 << __LINE__ << ", in function: " << __func__ << std::endl;
1194 }
1195 return false;
1196 }
1197 }
1198 else
1199 {
1200 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1201 {
1202 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1203 {
1204 std::cout << "Arg M (" << karg.M
1205 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1206 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1207 << __LINE__ << ", in function: " << __func__ << std::endl;
1208 }
1209 return false;
1210 }
1211 }
1212
1214 {
1215 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1216 {
1217 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1218 {
1219 std::cout << "Arg N (" << karg.N
1220 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1221 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1222 << __LINE__ << ", in function: " << __func__ << std::endl;
1223 }
1224 return false;
1225 }
1226 }
1227 else
1228 {
1229 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1230 {
1231 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1232 {
1233 std::cout << "Arg K (" << karg.K
1234 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1235 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1236 << __LINE__ << ", in function: " << __func__ << std::endl;
1237 }
1238 return false;
1239 }
1240 }
1241
1243 {
1245 {
1246 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1247 {
1248 std::cout << "Arg N (" << karg.N
1249 << ") value is not a multiple of "
1250 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1252 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1253 << std::endl;
1254 }
1255 return false;
1256 }
1257 }
1258 else
1259 {
1261 {
1262 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1263 {
1264 std::cout << "Arg M (" << karg.M
1265 << ") value is not a multiple of "
1266 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1268 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1269 << std::endl;
1270
1271 return false;
1272 }
1273 }
1274 }
1275
1276 // check gridwise gemm pipeline
1277#if 0
1278 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1279
1280 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1281 {
1282 return false;
1283 }
1284#endif
1285 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1286 return true;
1287 }
1288
1289 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1290 {
1291 const index_t num_loop = K / KPerBlock;
1292
1293 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1294 }
1295
1296 __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1297 {
1298 const index_t num_loop = K / KPerBlock;
1299
1300 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1301 }
1302
/// Applies transform_tensor_descriptor to the 2-D C descriptor and returns the result.
/// NOTE(review): judging by the name, the result is indexed
/// (MBlock, MPerBlock, NBlock, NPerBlock). The transform arguments (listing lines
/// 1309-1312) are absent from this text, so the exact transforms cannot be confirmed here.
1303 template <typename CGridDesc>
1304 __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1305 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1306 {
1307 const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1308 c_grid_desc_m_n,
1313
1314 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1315 }
1316
1317 // return block_id to C matrix tile idx (m0, n0) mapping
1318 // if arch = gfx942
1319 // using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock,
1320 // NPerBlock>;
1321
     // Size ratio between one stored A / B scale element and one mx_scale_t, i.e. how many
     // mx_scale_t values fit in a single AScaleDataType / BScaleDataType element.
     // NOTE(review): mx_scale_t is declared outside the visible text (listing line 1322).
1323 static constexpr index_t scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
1324 static constexpr index_t scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
     // The per-thread scale loads in Run() read KXdlPack * {M,N}XdlPack / scale_pack_size
     // elements, so the pack size has to divide that product evenly.
1325 static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
1326 "A scale pack data type too large!");
1327 static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
1328 "B scale pack data type too large!");
1329
/// Single-LDS-buffer MoE MX GEMM body executed by one workgroup.
///
/// Picks the (N tile, M tile) for this workgroup, gathers A rows through
/// p_sorted_token_ids, reads B (and, when IsInputGemm, a second "up" B slice) from the
/// selected expert, runs BlockwiseGemmPipe, applies the optional activation / routed-weight
/// scaling in registers, then stages the result through LDS and scatters it to p_c_grid.
///
/// @tparam HasMainKBlockLoop          forwarded to BlockwiseGemmPipe::Run.
/// @tparam CGlobalMemoryDataOperation memory operation used for the final global write.
/// @tparam TailNum                    forwarded to BlockwiseGemmPipe::Run.
/// @param p_sorted_token_ids  per-row entries; the low 24 bits are the token index and
///                            bits 24+ are added as the top-k slot wherever rows are
///                            expanded by TopK.
/// @param p_sorted_expert_ids expert id per M block (indexed by expert_block_id).
/// @param p_max_token_id      [0] bounds the valid sorted positions; with NSwizzle, entries
///                            [1 + expert_id] and [2 + expert_id] are read as block prefixes.
/// @param p_a_grid            A data.
/// @param p_a_scale_grid      A scales.
/// @param p_b_grid            B data; the per-expert offset is applied inside.
/// @param p_b_scale_grid      B scales; the per-expert offset is applied inside.
/// @param p_ds_grid           D tensor pointers; index I2 appears to be read as the routed
///                            weights when MulRoutedWeight is set.
/// @param p_c_grid            output tensor.
/// @param p_shared            LDS workspace: A tile first, B tile(s) after it, later reused
///                            from offset 0 for the C shuffle.
/// @param problem             sizes, strides and block counts.
/// @param a_element_op        element-wise functor handed to the A copy.
/// @param b_element_op        element-wise functor handed to the B copies.
/// @param c_element_op        element-wise functor handed to the output copy.
///
/// NOTE(review): the embedded listing numbers are not contiguous in several places (for
/// example 1368 -> 1370), so some lines of the original file are missing from this text;
/// those gaps are left exactly as found.
1330 template <bool HasMainKBlockLoop,
1331 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1332 TailNumber TailNum = TailNumber::Odd>
1333 __device__ static void Run(const index_t* p_sorted_token_ids,
1334 const index_t* p_sorted_expert_ids,
1335 const index_t* p_max_token_id,
1336 const ADataType* p_a_grid,
1337 const AScaleDataType* p_a_scale_grid,
1338 const BDataType* p_b_grid,
1339 const BScaleDataType* p_b_scale_grid,
1340 DsGridPointer& p_ds_grid,
1341 CDataType* p_c_grid,
1342 void* p_shared,
1343 const Problem& problem,
1344 AElementwiseOperation a_element_op,
1345 BElementwiseOperation b_element_op,
1346 CElementwiseOperation c_element_op)
1347 {
1348 ignore = b_element_op;
     // Global descriptors. A has NumTokens rows for the input GEMM and NumTokens * TopK
     // rows otherwise; C is sized the other way round.
1349 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1350 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
1351 problem.MPadded,
1352 problem.K,
1353 problem.KPadded,
1354 problem.StrideA,
1355 problem.AK0);
1356 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1357 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1358 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1359 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
1360 problem.MPadded,
1361 problem.N,
1362 problem.NPadded,
1363 problem.StrideC);
1364
     // Packed descriptors for the A / B scale tensors.
1365 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
1366 make_tuple(problem.M / (MXdlPack * MPerXdl),
1367 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
1368 (KXdlPack * 64 / MPerXdl),
1370
1371 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
1372 make_tuple(problem.N / (NXdlPack * NPerXdl),
1373 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
1374 (KXdlPack * 64 / NPerXdl),
1376
1377 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1379 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1380
     // Skip workgroups whose M block starts at or beyond the last valid sorted position.
1381 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1382 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1383 if(expert_block_id * MPerBlock >= max_token_id)
1384 return;
1385 const index_t expert_id =
1386 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1387
     // Map the workgroup to (N tile, M tile). With NSwizzle the linear blockIdx.x is
     // remapped in groups of 8 N tiles within the expert's M blocks; otherwise blockIdx.x
     // and blockIdx.y are used directly.
1388 const auto block_mn = [&]() -> std::pair<int, int> {
1389 if constexpr(NSwizzle)
1390 {
1391 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1392 const index_t prefix_block = ecnt_prefix * problem.NBlock;
1393 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1394 const index_t expert_swizzle =
1395 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
1396 const index_t bid_new = blockIdx.x - prefix_block;
1397 const index_t nid = __builtin_amdgcn_readfirstlane(
1398 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1399 const index_t mid =
1400 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1401 return {nid, mid};
1402 }
1403 else
1404 {
1405 return {blockIdx.x, blockIdx.y};
1406 }
1407 }();
1408
1409 const index_t block_n_id = block_mn.first;
1410 const index_t block_m_id = block_mn.second;
     // Token index (low 24 bits) of the first row of this M block.
1411 const index_t token0 =
1412 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1413
1414 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1415 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1416 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1417 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1418 constexpr auto AKThreads = AK0Threads * AK1Threads;
1419 constexpr auto AMRepeats = MPerBlock / AMThreads;
     // First sorted row handled by this thread in the A copy.
1420 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1421
     // NOTE(review): token_pos depends on threadIdx.x, so this return can be taken by only
     // part of the workgroup - confirm that is safe with the LDS synchronisation further down.
1422 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1423 return;
     // Per-thread gather offsets into A: (token row) * K. Rows are expanded by TopK when
     // this is not the input GEMM.
1425 static_for<0, AMRepeats, 1>{}([&](auto m0) {
1426 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1427 index_t token_offset = fused_token & 0xffffff;
1428 if constexpr(!IsInputGemm)
1429 {
1430 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1431 }
1432 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1433 });
1434
     // Per-expert strides for B and its scales; both are doubled for the input GEMM.
1435 const index_t expert_stride =
1436 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
1437 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1438 problem.N * (IsInputGemm ? 2 : 1) *
1439 math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
1440
1441 // N0, K0, Blocksize*KPack
1442 const index_t n_block_data_idx_on_grid =
1443 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1444
1445 // Grid buffer creation
1446 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1447 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1448 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1449 p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1450
1451 // A, B scale buffer
1452 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1453 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1454 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1455 p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
1456 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1457
1458 // lds max alignment
1459 constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1460
1461 // A matrix in LDS memory, dst of blockwise copy
1462 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1463
1464 // B matrix in LDS memory, dst of blockwise copy
1465 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1466
1467 // A matrix blockwise copy
1468 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
1470 AElementwiseOperation,
1474 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1475 ABlockTransferThreadClusterArrangeOrder,
1476 ADataType,
1477 ADataType,
1478 decltype(a_grid_desc_ak0_m_ak1),
1479 decltype(a_block_desc_ak0_m_ak1),
1480 ABlockTransferSrcAccessOrder,
1482 ABlockTransferSrcVectorDim,
1483 2,
1484 ABlockTransferSrcScalarPerVector,
1485 ABlockTransferDstScalarPerVector_AK1,
1486 1,
1487 1,
1488 AThreadTransferSrcResetCoordinateAfterRun,
1489 true,
1490 IndexType,
1491 1,
1492 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1493 make_multi_index(0, 0, 0),
1494 a_element_op,
1495 a_block_desc_ak0_m_ak1,
1496 make_multi_index(0, 0, 0),
1498 gather_offsets);
1499
1500 // B matrix blockwise copy
1501 auto b_blockwise_copy =
1503 BElementwiseOperation,
1507 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1508 BBlockTransferThreadClusterArrangeOrder,
1509 BDataType,
1510 BDataType,
1511 decltype(b_grid_desc_bk0_n_bk1),
1512 decltype(b_block_desc_bk0_n_bk1),
1513 BBlockTransferSrcAccessOrder,
1515 BBlockTransferSrcVectorDim,
1516 2,
1517 BBlockTransferSrcScalarPerVector,
1518 BBlockTransferDstScalarPerVector_BK1,
1519 1,
1520 1,
1521 BThreadTransferSrcResetCoordinateAfterRun,
1522 true,
1523 BlockwiseGemmPipe::GlobalBufferNum>(
1524 b_grid_desc_bk0_n_bk1,
1525 make_multi_index(0, n_block_data_idx_on_grid, 0),
1526 b_element_op,
1527 b_block_desc_bk0_n_bk1,
1528 make_multi_index(0, 0, 0),
1530
1531 // LDS allocation for A and B: be careful of alignment
1532 constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1533 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1534
     // The A tile sits at the start of p_shared; the B tile follows the aligned A tile.
1535 // Cast after lds
1537 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1538
1540 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1541 a_block_space_size_aligned * sizeof(ADataType)),
1542 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1543
     // Per-iteration window step along the K0 dimension for the A / B copies.
1544 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1545 constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1546
1547 // Blockwise GEMM pipeline
1548 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1549 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1550 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
     // Second accumulator, passed to the pipeline only in the IsInputGemm (gate + up) path.
1551 decltype(c_thread_buf) c_thread_buf_up;
1552
1554 float,
1555 c_thread_buf.num_of_v_,
1556 c_thread_buf.s_per_v,
1557 true>
1558 c_thread_buf_fp32;
1559
     // Number of K iterations: (AK0 * AK1) / KPerBlock.
1560 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1561 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1562 KPerBlock);
1563
1564 // a and b scale processing
1565 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
1566 const auto waveId_m = wave_idx[I0];
1567 const auto waveId_n = wave_idx[I1];
1568
1569 auto thread_offset_shuffled =
1570 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
1571
1572 auto a_thread_offset_m = waveId_m;
1573
1574 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1575 AScaleDataType,
1576 AScaleDataType,
1577 decltype(a_scale_grid_desc_am_ak),
1578 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
1579 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
1580 Sequence<0, 1, 2>, // DimAccessOrder
1581 2, // SrcVectorDim
1582 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
1583 1, // SrcScalarStrideInVector
1584 true>(a_scale_grid_desc_am_ak,
1585 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
1586 0,
1587 thread_offset_shuffled / scale_pack_size_a));
1588
1589 // B scale load
1590 auto b_thread_offset_n = waveId_n;
1591
1592 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
1593 BScaleDataType,
1594 BScaleDataType,
1595 decltype(b_scale_grid_desc_bn_ak),
1596 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1597 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1598 Sequence<0, 1, 2>, // DimAccessOrder
1599 2, // SrcVectorDim
1600 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1601 1, // SrcScalarStrideInVector
1602 true>(b_scale_grid_desc_bn_ak,
1603 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1604 0,
1605 thread_offset_shuffled / scale_pack_size_b));
1606
     // Input GEMM: a second ("up") B slice is processed next to the first ("gate") one.
     // Its LDS tile follows the gate tile, and its global data / scales start half an
     // expert stride after the gate data / scales.
1607 if constexpr(IsInputGemm)
1608 {
1609 constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1610 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1611 auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1612 reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
1613 a_block_space_size_aligned * sizeof(ADataType) +
1614 b_block_space_size_aligned * sizeof(BDataType)),
1615 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1616
1617 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
1618 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1619 p_b_grid_up + expert_id * expert_stride,
1620 b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1621
1622 auto b_blockwise_copy_up =
1624 BElementwiseOperation,
1628 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1629 BBlockTransferThreadClusterArrangeOrder,
1630 BDataType,
1631 BDataType,
1632 decltype(b_grid_desc_bk0_n_bk1),
1633 decltype(b_block_desc_bk0_n_bk1),
1634 BBlockTransferSrcAccessOrder,
1636 BBlockTransferSrcVectorDim,
1637 2,
1638 BBlockTransferSrcScalarPerVector,
1639 BBlockTransferDstScalarPerVector_BK1,
1640 1,
1641 1,
1642 BThreadTransferSrcResetCoordinateAfterRun,
1643 true,
1644 BlockwiseGemmPipe::GlobalBufferNum>(
1645 b_grid_desc_bk0_n_bk1,
1646 make_multi_index(0, n_block_data_idx_on_grid, 0),
1647 b_element_op,
1648 b_block_desc_bk0_n_bk1,
1649 make_multi_index(0, 0, 0),
1651
1652 const BScaleDataType* p_b_scale_grid_up =
1653 p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
1654 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1655 p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
1656 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1657
1658 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
1659 BScaleDataType,
1660 BScaleDataType,
1661 decltype(b_scale_grid_desc_bn_ak),
1662 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
1663 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
1664 Sequence<0, 1, 2>, // DimAccessOrder
1665 2, // SrcVectorDim
     // N-side pack factor: has to agree with the slice length above and with the gate copy.
1666 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
1667 1, // SrcScalarStrideInVector
1668 true>(
1669 b_scale_grid_desc_bn_ak,
1670 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
1671 0,
1672 thread_offset_shuffled / scale_pack_size_b));
1673
1674 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1675 // A
1676 a_grid_desc_ak0_m_ak1,
1677 a_block_desc_ak0_m_ak1,
1678 a_blockwise_copy,
1679 a_grid_buf,
1680 a_block_buf,
1681 a_block_slice_copy_step,
1682 // Gate and Up
1683 b_grid_desc_bk0_n_bk1,
1684 b_block_desc_bk0_n_bk1,
1685 b_blockwise_copy,
1686 b_blockwise_copy_up,
1687 b_grid_buf,
1688 b_grid_buf_up,
1689 b_block_buf,
1690 b_block_buf_up,
1691 b_block_slice_copy_step,
1692 // C
1693 c_thread_buf,
1694 c_thread_buf_up,
1695 // A scale
1696 a_scale_grid_desc_am_ak,
1697 a_scale_thread_copy,
1698 a_scale_grid_buf,
1699 // Gate and Up scale
1700 b_scale_grid_desc_bn_ak,
1701 b_scale_thread_copy,
1702 b_scale_thread_copy_up,
1703 b_scale_grid_buf,
1704 b_scale_grid_buf_up,
1705 num_k_block_main_loop);
1706 }
1707 else
1708 {
1709 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
1710 a_grid_desc_ak0_m_ak1, // A
1711 a_block_desc_ak0_m_ak1,
1712 a_blockwise_copy,
1713 a_grid_buf,
1714 a_block_buf,
1715 a_block_slice_copy_step,
1716 b_grid_desc_bk0_n_bk1, // B
1717 b_block_desc_bk0_n_bk1,
1718 b_blockwise_copy,
1719 b_grid_buf,
1720 b_block_buf,
1721 b_block_slice_copy_step,
1722 c_thread_buf, // C
1723 a_scale_grid_desc_am_ak, // A scale
1724 a_scale_thread_copy,
1725 a_scale_grid_buf,
1726 b_scale_grid_desc_bn_ak, // B scale
1727 b_scale_thread_copy,
1728 b_scale_grid_buf,
1729 num_k_block_main_loop);
1730 }
1731
1732 // shuffle C and write out
1733 {
1734 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1735 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1736 "wrong!");
1737 static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
1738 CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
1739 "wrong!");
1740
1741 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1742 constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1743
1744 // TODO: hacky, fix it!
1745 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1746 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1747
1748 // TODO: hacky, fix it!
1749 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1750 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1751 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
1752
1753 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1754 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1755 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1756 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1757 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1758 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1759 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1760 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1761 constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
1762 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
1763
1764 // mul scales
1765 static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
1766 static_assert(M5 == 4);
1767 const index_t m1 = get_warp_local_1d_id() / NWave; // Mwave id
1768 const index_t m4 = threadIdx.x % get_warp_size() / MPerXdl;
1769
     // Register epilogue: walk every accumulator element of this thread, apply the
     // gate/up fusion (input GEMM) and/or routed-weight scaling, and store the fp32
     // result in c_thread_buf_fp32.
1770 vector_type<float, 4> topk_weights; // for gemm2 only
1771 static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
1772 static_for<0, NXdlPack, 1>{}([&](auto inxdl) { // NXdlPack
1773 static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) { // MXDLPerWave
1774 static_for<0, MXdlPack, 1>{}([&](auto imxdl) { // MXdlPack
1775 static_for<0, M3, 1>{}([&](auto m3) { // m_inst_num_groups_per_blk
     // Sorted row position of the first of the M5 consecutive rows handled below.
1776 const index_t m_pos = block_m_id * MPerBlock +
1777 m0 * M2 * M1 * M3 * M4 * M5 +
1778 m1 * M2 * M3 * M4 * M5 +
1779 imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
1780
1781 if constexpr(MulRoutedWeight)
1782 {
1783 topk_weights =
1785 p_ds_grid[I2] + m_pos);
1786 }
1787 static_for<0, M5, 1>{}([&](auto m5) { // m_inst_group_size
1788 constexpr index_t c_offset =
1789 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1790 make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
1791 constexpr auto cidx = Number<c_offset>{};
1792
1793 if constexpr(IsInputGemm) // gu fusion
1794 {
1795 if constexpr(ActivationOperation ==
1796 Activation::silu_and_mul)
1797 {
1798 float gate = c_thread_buf[cidx];
1799 float up = c_thread_buf_up[cidx];
1800 if constexpr(MulRoutedWeight)
1801 {
1802 gate = gate * topk_weights.AsType<float>()[m5];
1803 up = up * topk_weights.AsType<float>()[m5];
1804 }
1806 c_thread_buf_fp32(cidx) = gate * up;
1807 }
     // ActivationOperation is a compile-time value, so select this branch at compile
     // time like the branch above.
1808 else if constexpr(ActivationOperation == Activation::gelu_and_mul)
1809 {
1810 float gate = c_thread_buf[cidx];
1811 float up = c_thread_buf_up[cidx];
1812 if constexpr(MulRoutedWeight)
1813 {
1814 gate = gate * topk_weights.AsType<float>()[m5];
1815 up = up * topk_weights.AsType<float>()[m5];
1816 }
1818 c_thread_buf_fp32(cidx) = gate * up;
1819
1820 /*float gate = c_thread_buf[cidx];
1821 float up = c_thread_buf_up[cidx];
1822 if constexpr(MulRoutedWeight)
1823 {
1824 gate = gate * topk_weights.AsType<float>()[m5];
1825 //up = up * topk_weights.AsType<float>()[m5];
1826 }
1827 tensor_operation::element_wise::Gelu{}(gate, gate);
1828 c_thread_buf_fp32(cidx) = up;*/
1829 }
1830 }
1831 else
1832 {
1833 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
1834 if constexpr(MulRoutedWeight)
1835 {
1836 c_thread_buf_fp32(cidx) =
1837 topk_weights.AsType<float>()[m5] *
1838 c_thread_buf_fp32[cidx];
1839 }
1840 }
1841 });
1842 });
1843 });
1844 });
1845 });
1846 });
1847
1848 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1850
     // The C shuffle staging buffer reuses p_shared from offset 0.
1851 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1852 static_cast<CShuffleDataType*>(p_shared),
1853 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1854
1855 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1856 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1857 make_tuple(
1861 // per shuffle
1862 M1, // M1 = MWave
1863 M2, // M2 = MXdlPack
1864 M3, // M3 * M4 * M5 = MPerXdl
1865 M4,
1866 M5)),
1870 // per shuffle
1871 N1, // N1 = NWave
1872 N2, // N2 = NXdlPack
1873 N3))), // N3 = NPerXdl
1877 Sequence<>{},
1879
1880 // calculate origin of thread output tensor on global memory
1881 // blockwise GEMM c matrix starting index
1882 const auto c_thread_mtx_on_block =
1883 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1884
1885 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1886 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1887
     // Split the thread's linear M / N origin into the multi-level indices used by the
     // shuffle descriptor.
1888 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1890 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))),
1893
1894 const auto m_thread_data_on_block_idx =
1895 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1896 make_multi_index(m_thread_data_on_block));
1897
1898 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1900 make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))),
1903
1904 const auto n_thread_data_on_block_idx =
1905 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1906 make_multi_index(n_thread_data_on_block));
1907
1908 // shuffle: threadwise copy C from VGPR to LDS
1909 auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1910 AccDataType,
1911 CShuffleDataType,
1912 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1913 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1915 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
1916 CShuffleNXdlPerWavePerShuffle / NXdlPack,
1917 I1,
1918 I1,
1919 M2,
1920 N2,
1921 M3,
1922 I1,
1923 M5,
1924 I1>,
1926 9,
1927 1,
1929 1,
1930 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1932 0,
1933 m_thread_data_on_block_idx[I1],
1934 n_thread_data_on_block_idx[I1],
1935 m_thread_data_on_block_idx[I2],
1936 n_thread_data_on_block_idx[I2],
1937 m_thread_data_on_block_idx[I3],
1938 m_thread_data_on_block_idx[I4],
1939 m_thread_data_on_block_idx[I5],
1940 n_thread_data_on_block_idx[I3]),
1942
1943 using EDataType = CDataType;
1944
1945 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
1946 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
1947
1948 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1950 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
1951
1952 const auto ds_grid_buf = generate_tuple(
1953 [&](auto i) {
1955 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
1956 },
1958
1959 // tuple of reference to C/Ds tensor descriptors
1960 const auto c_ds_desc_refs = concat_tuple_of_reference(
1961 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1962 generate_tie([&](auto i) -> const auto& // return type should be reference
1963 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1965
1966 // tuple of reference to C/Ds tensor descriptors
1967 const auto c_ds_buf_refs = concat_tuple_of_reference(
1968 tie(c_shuffle_block_buf),
1969 generate_tie([&](auto i) -> const auto& // return type should be reference
1970 { return ds_grid_buf[i]; },
1972
1973 // tuple of starting index of C/Ds blockwise copy
1974 const auto idx_c_ds_block_begin =
1977 [&](auto) {
1978 return make_multi_index(block_m_id, 0, block_n_id, 0);
1979 // return make_multi_index(block_work_idx[I0], 0,
1980 // block_work_idx[I1], 0);
1981 },
1983
1984 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1985 c_grid_desc_mblock_mperblock_nblock_nperblock;
1986
1987 using CDEBlockTransferCluster =
1988 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1989 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1990 constexpr index_t scatter_weight_idx = 3; // hack fix felix
1991 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
1993 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
1995 decltype(c_ds_desc_refs),
1996 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1997 CElementwiseOperation,
1998 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
1999 // Sequence support
2000 // arbitrary type
2001 Sequence<1,
2002 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2003 1,
2004 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2005 CDEBlockTransferCluster,
2006 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2007 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2008 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2009 3, // index_t SrcVectorDim,
2010 3, // index_t DstVectorDim,
2011 CDEShuffleBlockTransferScalarPerVectors,
2016 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2017 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2018 IndexType,
2019 1, // ScatterDim
2020 true, // OutputScatter: false, only use scatter weights
2021 scatter_weight_idx // ScatterWeightIdx: ascale
2022 >{c_ds_desc_refs,
2023 idx_c_ds_block_begin,
2024 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2025 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2026 c_element_op};
2027
2029 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2030
2031 constexpr auto sfc_c_vgpr =
2032 SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
2033 NXdlPerWave / NXdlPack,
2034 1,
2035 1,
2036 MXdlPack,
2037 NXdlPack,
2038 M2,
2039 1,
2040 M4,
2041 1>,
2043 Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
2044 CShuffleNXdlPerWavePerShuffle / NXdlPack,
2045 1,
2046 1,
2047 MXdlPack,
2048 NXdlPack,
2049 M2,
2050 1,
2051 M4,
2052 1>>{};
2053
2054 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2055
2056 // space filling curve for shuffled blockwise C/D/E
2057 constexpr auto sfc_cde_block =
2060 Sequence<1,
2061 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2062 1,
2063 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2064
2065 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2066 constexpr auto EMThreads =
2067 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2068 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2069 constexpr auto ENThreads =
2070 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
     // One iteration per shuffle step: compute the output row offsets, copy this thread's
     // accumulators to LDS, then let the block copy LDS (+ Ds) to global memory.
2071 static_for<0, num_access, 1>{}([&](auto access_id) {
2072 // make sure it's safe to write to LDS
2074
2075 auto dstidx = sfc_cde_block.GetIndex(access_id);
2076 const index_t c_token_pos =
2077 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
     // Output row offsets: (token row) * N. Rows are expanded by TopK for the input GEMM.
2078 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2079 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2080 IndexType token_offset = fused_token & 0xffffff;
2081 if constexpr(IsInputGemm)
2082 {
2083 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2084 }
2085 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2086 });
2087
2089
2090 // each thread write its data from VGPR to LDS
2091 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2092 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2093 c_thread_buf_fp32,
2094 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2095 c_shuffle_block_buf);
2096
2097 // make sure it's safe to read from LDS
2099
2100 // each block copy its data from LDS to global
2101 cde_block_copy_lds_and_global.Run(
2102 c_ds_desc_refs,
2103 c_ds_buf_refs,
2104 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2105 tie(c_grid_buf),
2106 scatter_offsets);
2107
     // Advance the source (Ds) and destination (E) windows, except after the last step.
2108 if constexpr(access_id < num_access - 1)
2109 {
2110 constexpr auto cde_lds_and_global_step =
2111 sfc_cde_block.GetForwardStep(access_id);
2112
2113 // move on Ds
2114 static_for<0, NumDTensor, 1>{}([&](auto i) {
2115 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2116 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2117 });
2118
2119 // move on E
2120 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2121 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2122 I0,
2123 cde_lds_and_global_step);
2124 }
2125 });
2126 }
2127 }
2128
2129#if 0
2130 template <bool HasMainKBlockLoop,
2131 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2132 TailNumber TailNum = TailNumber::Odd>
2133 __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
2134 const index_t* p_sorted_expert_ids,
2135 const index_t* p_max_token_id,
2136 const ADataType* p_a_grid,
2137 const AScaleDataType* p_a_scale_grid,
2138 const BDataType* p_b_grid,
2139 const BScaleDataType* p_b_scale_grid,
2140 DsGridPointer& p_ds_grid,
2141 CDataType* p_c_grid,
2142 void* p_shared,
2143 void* p_shared1,
2144 const Problem& problem,
2145 AElementwiseOperation a_element_op,
2146 BElementwiseOperation b_element_op,
2147 CElementwiseOperation c_element_op)
2148 {
2149 ignore = b_element_op;
2150 const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2151 IsInputGemm ? problem.NumTokens : problem.NumTokens * problem.TopK,
2152 problem.MPadded,
2153 problem.K,
2154 problem.KPadded,
2155 problem.StrideA,
2156 problem.AK0);
2157 const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2158 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2159 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
2160 IsInputGemm ? problem.NumTokens * problem.TopK : problem.NumTokens,
2161 problem.MPadded,
2162 problem.N,
2163 problem.NPadded,
2164 problem.StrideC);
2165
2166 const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(
2167 make_tuple((IsInputGemm ? problem.NumTokens : problem.M) / (MXdlPack * MPerXdl),
2168 math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) /
2169 (KXdlPack * 64 / MPerXdl),
2171
2172 const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(
2173 make_tuple(problem.N / (NXdlPack * NPerXdl),
2174 math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) /
2175 (KXdlPack * 64 / NPerXdl),
2177
2178 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2180 c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2181 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
2182 // static_assert(NSwizzle == false, "to do fix: need another pr in sorting merged");
2183 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
2184 if(expert_block_id * MPerBlock >= max_token_id)
2185 return;
2186 const index_t expert_id =
2187 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
2188 const auto block_mn = [&]() -> std::pair<int, int> {
2189 if constexpr(NSwizzle)
2190 {
2191 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
2192 const index_t prefix_block = ecnt_prefix * problem.NBlock;
2193 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
2194 const index_t expert_swizzle =
2195 ecnt > 0 ? ecnt : 1; // p_max_token_id[expert_id + 1]; // 2
2196 const index_t bid_new = blockIdx.x - prefix_block;
2197 const index_t nid = __builtin_amdgcn_readfirstlane(
2198 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
2199 const index_t mid =
2200 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
2201 return {nid, mid};
2202 }
2203 else
2204 {
2205 return {blockIdx.x, blockIdx.y};
2206 }
2207 }();
2208
2209 const index_t block_n_id = block_mn.first;
2210 const index_t block_m_id = block_mn.second;
2211 const index_t token0 =
2212 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
2213
2214 // constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2215 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
2216 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
2217 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
2218 constexpr auto AKThreads = AK0Threads * AK1Threads;
2219 constexpr auto AMRepeats = MPerBlock / AMThreads;
2220 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2221
2222 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
2223 return;
2225 static_for<0, AMRepeats, 1>{}([&](auto m0) {
2226 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2227 index_t token_offset = fused_token & 0xffffff;
2228 if constexpr(!IsInputGemm)
2229 {
2230 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2231 }
2232 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2233 });
2234
2235 const index_t expert_stride =
2236 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2237 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2238 problem.N * math::integer_divide_ceil(problem.K, ScaleBlockSize / BPackedSize));
2239
2240 // N0, K0, Blocksize*KPack
2241 const index_t n_block_data_idx_on_grid =
2242 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2243
2244 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2245 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2246
2247 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2248 p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());
2249
2250 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2251 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2252 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2253 p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
2254 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2255
2256 // A matrix in LDS memory, dst of blockwise copy
2257 constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
2258
2259 // B matrix in LDS memory, dst of blockwise copy
2260 // dummy
2261 constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
2262 // A matrix blockwise copy
2263 auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
2265 AElementwiseOperation,
2266 ck::tensor_operation::element_wise::PassThrough,
2268 Sequence<AK0Number, MPerBlock, AK1Number>,
2269 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2270 ABlockTransferThreadClusterArrangeOrder,
2271 ADataType,
2272 LDSTypeA,
2273 decltype(a_grid_desc_ak0_m_ak1),
2274 decltype(a_block_desc_ak0_m_ak1),
2275 ABlockTransferSrcAccessOrder,
2276 Sequence<0, 1, 2>,
2277 ABlockTransferSrcVectorDim,
2278 2,
2279 ABlockTransferSrcScalarPerVector,
2280 ABlockTransferDstScalarPerVector_AK1,
2281 1,
2282 1,
2283 AThreadTransferSrcResetCoordinateAfterRun,
2284 true,
2285 IndexType,
2286 1,
2287 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2288 make_multi_index(0, 0, 0),
2289 a_element_op,
2290 a_block_desc_ak0_m_ak1,
2291 make_multi_index(0, 0, 0),
2292 ck::tensor_operation::element_wise::PassThrough{},
2293 gather_offsets);
2294
2295 // Thread-wise copy
2296 // K0 -> N0/NWave -> NWave -> KLane -> NLane -> KPack
2298 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2300 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2301 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2302
2303 auto b_blockwise_copy =
2304 ThreadwiseTensorSliceTransfer_v2<BDataType,
2305 BDataType,
2306 decltype(b_grid_desc_bpreshuffled),
2307 decltype(b_block_desc_bk0_n_bk1),
2308 Sequence<Number<NXdlPerWave / NXdlPack>{},
2309 I1,
2313 Sequence<1, 2, 0, 3, 4>,
2314 4,
2315 BBlockTransferSrcScalarPerVector,
2316 BThreadTransferSrcResetCoordinateAfterRun,
2317 true>(
2318 b_grid_desc_bpreshuffled,
2319 make_multi_index(n_block_data_idx_on_grid,
2320 get_warp_local_1d_id() % NWave,
2321 0,
2322 0,
2323 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2324
2325 // LDS allocation for A and B: be careful of alignment
2326 // Cast after lds
2327 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2328 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2329 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2330 static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2331 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2332
2333 constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
2334 constexpr auto b_block_slice_copy_step = make_multi_index(0, 0, 0, KRepeat, 0);
2335
2336 // Blockwise GEMM pipeline
2337 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2338 auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
2339 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2340 decltype(c_thread_buf) c_thread_buf_up;
2341
2342 StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
2343 float,
2344 c_thread_buf.num_of_v_,
2345 c_thread_buf.s_per_v,
2346 true>
2347 c_thread_buf_fp32;
2348
2349 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2350 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2351 KPerBlock);
2352
2353 // a and b scale processing
2354 const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
2355 const auto waveId_m = wave_idx[I0];
2356 const auto waveId_n = wave_idx[I1];
2357
2358 auto thread_offset_shuffled =
2359 get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack;
2360
2361 auto a_thread_offset_m = waveId_m;
2362
2363 // get each thread's offset in the scale tensor
2364 const index_t token_scale_pos = block_m_id * MPerBlock;
2365 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2366 return;
2367
2368 auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2369 AScaleDataType,
2370 AScaleDataType,
2371 decltype(a_scale_grid_desc_am_ak),
2372 decltype(BlockwiseGemmPipe::a_scale_thread_desc),
2373 Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths
2374 Sequence<0, 1, 2>, // DimAccessOrder
2375 2, // SrcVectorDim
2376 KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector
2377 1, // SrcScalarStrideInVector
2378 true>(a_scale_grid_desc_am_ak,
2379 make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m,
2380 0,
2381 thread_offset_shuffled / scale_pack_size_a));
2382
2383 // B scale load
2384 auto b_thread_offset_n = waveId_n;
2385
2386 auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
2387 BScaleDataType,
2388 BScaleDataType,
2389 decltype(b_scale_grid_desc_bn_ak),
2390 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2391 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2392 Sequence<0, 1, 2>, // DimAccessOrder
2393 2, // SrcVectorDim
2394 KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector
2395 1, // SrcScalarStrideInVector
2396 true>(b_scale_grid_desc_bn_ak,
2397 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2398 0,
2399 thread_offset_shuffled / scale_pack_size_b));
2400
2401 if constexpr(IsInputGemm)
2402 {
2403 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2404 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2405 p_b_grid_up + expert_id * expert_stride / BPackedSize,
2406 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2407 auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
2408 BDataType,
2409 BDataType,
2410 decltype(b_grid_desc_bpreshuffled),
2411 decltype(b_block_desc_bk0_n_bk1),
2412 Sequence<Number<NXdlPerWave>{}, I1, Number<KRepeat>{}, Number<BK1Value>{}>,
2413 Sequence<1, 2, 0, 3>,
2414 3,
2415 BBlockTransferSrcScalarPerVector,
2416 BThreadTransferSrcResetCoordinateAfterRun,
2417 true>(b_grid_desc_bpreshuffled,
2418 make_multi_index(n_block_data_idx_on_grid,
2419 get_warp_local_1d_id() % NWave,
2420 0,
2421 KPack / KGroup * (get_thread_local_1d_id() % WarpSize)));
2422 const BScaleDataType* p_b_scale_grid_up = p_b_scale_grid + expert_scale_stride / 2;
2423 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2424 p_b_scale_grid_up + expert_id * expert_scale_stride,
2425 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2426 auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
2427 BScaleDataType,
2428 BScaleDataType,
2429 decltype(b_scale_grid_desc_bn_ak),
2430 decltype(BlockwiseGemmPipe::b_scale_thread_desc),
2431 Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths
2432 Sequence<0, 1, 2>, // DimAccessOrder
2433 2, // SrcVectorDim
2434 KXdlPack * MXdlPack / scale_pack_size_b, // SrcScalarPerVector
2435 1, // SrcScalarStrideInVector
2436 true>(
2437 b_scale_grid_desc_bn_ak,
2438 make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n,
2439 0,
2440 thread_offset_shuffled / scale_pack_size_b));
2441
2442 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2443 a_grid_desc_ak0_m_ak1,
2444 a_block_desc_ak0_m_ak1,
2445 a_blockwise_copy,
2446 a_grid_buf,
2447 a_block_bufs,
2448 a_block_slice_copy_step,
2449 b_grid_desc_bpreshuffled,
2450 b_block_desc_bk0_n_bk1,
2451 b_blockwise_copy,
2452 b_blockwise_copy_up,
2453 b_grid_buf,
2454 b_grid_buf_up,
2455 b_block_bufs,
2456 b_block_slice_copy_step,
2457 c_thread_buf,
2458 c_thread_buf_up,
2459 a_scale_grid_desc_am_ak,
2460 a_scale_thread_copy,
2461 a_scale_grid_buf,
2462 b_scale_grid_desc_bn_ak,
2463 b_scale_thread_copy,
2464 b_scale_thread_copy_up,
2465 b_scale_grid_buf,
2466 b_scale_grid_buf_up,
2467 num_k_block_main_loop);
2468 }
2469 else
2470 {
2471 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
2472 a_grid_desc_ak0_m_ak1,
2473 a_block_desc_ak0_m_ak1,
2474 a_blockwise_copy,
2475 a_grid_buf,
2476 a_block_bufs,
2477 a_block_slice_copy_step,
2478 b_grid_desc_bpreshuffled,
2479 b_block_desc_bk0_n_bk1,
2480 b_blockwise_copy,
2481 b_grid_buf,
2482 b_block_bufs,
2483 b_block_slice_copy_step,
2484 c_thread_buf,
2485 a_scale_grid_desc_am_ak,
2486 a_scale_thread_copy,
2487 a_scale_grid_buf,
2488 b_scale_grid_desc_bn_ak,
2489 b_scale_thread_copy,
2490 b_scale_grid_buf,
2491 num_k_block_main_loop);
2492 }
2493
2494 // shuffle C and write out
2495 {
2496 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2497 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2498 "wrong!");
2499
2500 // TODO: hacky, fix it!
2501 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
2502 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2503
2504 // TODO: hacky, fix it!
2505 // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
2506 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
2507 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
2508
2509 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
2510 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
2511 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
2512 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
2513 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
2514 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
2515 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
2516 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
2517
2518 // mul scales
2519
2520 static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
2521 static_assert(M4 == 4);
2522 const index_t m1 = get_warp_local_1d_id() / NWave;
2523 const index_t m3 = threadIdx.x % get_warp_size() / MPerXdl;
2524
2525 vector_type<float, 4> topk_weights; // for gemm2 only
2526 static_for<0, NXdlPerWave, 1>{}([&](auto n0) {
2527 static_for<0, MXdlPerWave, 1>{}([&](auto m0) { // MXDLPerWave
2528 static_for<0, M2, 1>{}([&](auto m2) { // m_inst_num_groups_per_blk
2529 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
2530 m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
2531 if constexpr(MulRoutedWeight)
2532 {
2534 p_ds_grid[I2] + m_pos);
2535 }
2536 static_for<0, M4, 1>{}([&](auto m4) { // m_inst_group_size
2537 constexpr index_t c_offset =
2538 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2539 make_tuple(m0 / MXdlPack,
2540 n0 / NXdlPack,
2541 m0 % MXdlPack,
2542 n0 % NXdlPack,
2543 m2 * M4 + m4));
2544 constexpr auto cidx = Number<c_offset>{};
2545
2546 if constexpr(IsInputGemm) // gu fusion
2547 {
2548 if constexpr(ActivationOperation == Activation::silu_and_mul)
2549 {
2550 float gate = c_thread_buf[cidx];
2551 float up = c_thread_buf_up[cidx];
2552 if constexpr(MulRoutedWeight)
2553 {
2554 gate = gate * topk_weights.AsType<float>()[m4];
2555 up = up * topk_weights.AsType<float>()[m4];
2556 }
2557 tensor_operation::element_wise::Silu{}(gate, gate);
2558 c_thread_buf_fp32(cidx) = gate * up;
2559 }
2560 else if(ActivationOperation == Activation::gelu_and_mul)
2561 {
2562 float gate = c_thread_buf[cidx];
2563 float up = c_thread_buf_up[cidx];
2564 if constexpr(MulRoutedWeight)
2565 {
2566 gate = gate * topk_weights.AsType<float>()[m4];
2567 up = up * topk_weights.AsType<float>()[m4];
2568 }
2569 tensor_operation::element_wise::Gelu{}(gate, gate);
2570 c_thread_buf_fp32(cidx) = gate * up;
2571 }
2572 }
2573 else
2574 {
2575 c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
2576 if constexpr(MulRoutedWeight)
2577 {
2578 c_thread_buf_fp32(cidx) =
2579 topk_weights.AsType<float>()[m4] * c_thread_buf_fp32[cidx];
2580 }
2581 }
2582 });
2583 });
2584 });
2585 });
2586
2587 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2589
2590 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2591 static_cast<CShuffleDataType*>(p_shared),
2592 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2593
2594 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2595 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2598 Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per
2599 // shuffle
2600 M1, // M1 = MWave
2601 M2, // M2 * M3 * M4 = MPerXdl
2602 M3,
2603 M4)),
2606 Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per
2607 // shuffle
2608 N1, // N1 = NWave
2609 N2))), // N2 = NPerXdl
2610 make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2611 make_tuple(
2612 Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2613
2614 // calculate origin of thread output tensor on global memory
2615 // blockwise GEMM c matrix starting index
2616 const auto c_thread_mtx_on_block =
2617 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2618
2619 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2620 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2621
2622 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2624 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2625 make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2626 make_tuple(Sequence<0>{}));
2627
2628 const auto m_thread_data_on_block_idx =
2629 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2630 make_multi_index(m_thread_data_on_block));
2631
2632 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2635 make_tuple(Sequence<0, 1, 2>{}),
2636 make_tuple(Sequence<0>{}));
2637
2638 const auto n_thread_data_on_block_idx =
2639 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2640 make_multi_index(n_thread_data_on_block));
2641
2642 // shuffle: threadwise copy C from VGPR to LDS
2643 auto c_thread_copy_vgpr_to_lds =
2644 ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2645 CShuffleDataType,
2646 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2647 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2648 ck::tensor_operation::element_wise::PassThrough,
2649 Sequence<CShuffleMXdlPerWavePerShuffle,
2650 CShuffleNXdlPerWavePerShuffle,
2651 I1,
2652 I1,
2653 M2,
2654 I1,
2655 M4,
2656 I1>,
2657 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2658 7,
2659 1,
2661 1,
2662 true>{
2663 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2665 0,
2666 m_thread_data_on_block_idx[I1],
2667 n_thread_data_on_block_idx[I1],
2668 m_thread_data_on_block_idx[I2],
2669 m_thread_data_on_block_idx[I3],
2670 m_thread_data_on_block_idx[I4],
2671 n_thread_data_on_block_idx[I2]),
2672 ck::tensor_operation::element_wise::PassThrough{}};
2673
2674 using EDataType = CDataType;
2675
2676 const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
2677 problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
2678
2679 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2681 ds_grid_desc_m_n, problem.MBlock, problem.NBlock);
2682
2683 const auto ds_grid_buf = generate_tuple(
2684 [&](auto i) {
2686 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2687 },
2689
2690 // tuple of reference to C/Ds tensor descriptors
2691 const auto c_ds_desc_refs = concat_tuple_of_reference(
2692 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2693 generate_tie([&](auto i) -> const auto& // return type should be reference
2694 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2696
2697 // tuple of reference to C/Ds buffers
2698 const auto c_ds_buf_refs = concat_tuple_of_reference(
2699 tie(c_shuffle_block_buf),
2700 generate_tie([&](auto i) -> const auto& // return type should be reference
2701 { return ds_grid_buf[i]; },
2703
2704 // tuple of starting index of C/Ds blockwise copy
2705 const auto idx_c_ds_block_begin =
2708 [&](auto) {
2709 return make_multi_index(block_m_id, 0, block_n_id, 0);
2710 // return make_multi_index(block_work_idx[I0], 0,
2711 // block_work_idx[I1], 0);
2712 },
2714
2715 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2716 c_grid_desc_mblock_mperblock_nblock_nperblock;
2717
2718 using CDEBlockTransferCluster =
2719 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2720 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2721 constexpr index_t scatter_weight_idx = 3; // hack fix felix
2722 auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
2724 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
2725 Tuple<EDataType>,
2726 decltype(c_ds_desc_refs),
2727 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2728 CElementwiseOperation,
2729 Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make
2730 // Sequence support
2731 // arbitrary type
2732 Sequence<1,
2733 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2734 1,
2735 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2736 CDEBlockTransferCluster,
2737 Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2738 Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
2739 Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
2740 3, // index_t SrcVectorDim,
2741 3, // index_t DstVectorDim,
2742 CDEShuffleBlockTransferScalarPerVectors,
2745 Sequence<true>,
2747 false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
2748 Sequence<false>, // ThreadTransferDstResetCoordinateAfterRunFlags
2749 IndexType,
2750 1, // ScatterDim
2751 true, // OutputScatter: false, only use scatter weights
2752 scatter_weight_idx // ScatterWeightIdx: ascale
2753 >{c_ds_desc_refs,
2754 idx_c_ds_block_begin,
2755 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2756 make_tuple(make_multi_index(0, 0, block_n_id, 0)),
2757 c_element_op};
2758
2760 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2761 constexpr auto sfc_c_vgpr =
2762 SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2763 Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2764 Sequence<CShuffleMXdlPerWavePerShuffle,
2765 CShuffleNXdlPerWavePerShuffle,
2766 1,
2767 1,
2768 M2,
2769 1,
2770 M4,
2771 1>>{};
2772
2773 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2774
2775 // space filling curve for shuffled blockwise C/D/E
2776 constexpr auto sfc_cde_block =
2777 SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2778 Sequence<0, 2, 1, 3>,
2779 Sequence<1,
2780 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2781 1,
2782 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2783
2784 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2785 constexpr auto EMThreads =
2786 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2787 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2788 constexpr auto ENThreads =
2789 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2790 static_for<0, num_access, 1>{}([&](auto access_id) {
2791 // make sure it's safe to write to LDS
2793
2794 auto dstidx = sfc_cde_block.GetIndex(access_id);
2795 const index_t c_token_pos =
2796 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2797 static_for<0, EMRepeats, 1>{}([&](auto m0) {
2798 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2799 IndexType token_offset = fused_token & 0xffffff;
2800 if constexpr(IsInputGemm)
2801 {
2802 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2803 }
2804 scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
2805 });
2806
2808
2809 // each thread write its data from VGPR to LDS
2810 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2811 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2812 c_thread_buf_fp32,
2813 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2814 c_shuffle_block_buf);
2815
2816 // make sure it's safe to read from LDS
2818
2819 // each block copy its data from LDS to global
2820 cde_block_copy_lds_and_global.Run(
2821 c_ds_desc_refs,
2822 c_ds_buf_refs,
2823 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2824 tie(c_grid_buf),
2825 scatter_offsets);
2826
2827 if constexpr(access_id < num_access - 1)
2828 {
2829 constexpr auto cde_lds_and_global_step =
2830 sfc_cde_block.GetForwardStep(access_id);
2831
2832 // move on Ds
2833 static_for<0, NumDTensor, 1>{}([&](auto i) {
2834 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2835 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2836 });
2837
2838 // move on E
2839 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2840 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2841 I0,
2842 cde_lds_and_global_step);
2843 }
2844 });
2845 }
2846 }
2847#endif
2848};
2849
2850} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr auto lcm(X x, Y y)
Definition utility/math.hpp:198
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition utility/sequence.hpp:928
__device__ index_t get_warp_local_1d_id()
Definition get_id.hpp:45
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
__host__ __device__ constexpr auto container_concat(const X &x, const Ys &... ys)
Definition utility/container_helper.hpp:320
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
__global__ void kernel_moe_mxgemm(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm_bns.hpp:48
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:185
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_moe_mxgemm_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_moe_mx_gemm.hpp:90
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Vgpr
Definition amd_address_space.hpp:20
__host__ __device__ PY c_style_pointer_cast(PX p_x)
Definition c_style_pointer_cast.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
Activation
Definition gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition gridwise_moe_gemm.hpp:32
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:84
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition utility/sequence.hpp:925
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
constexpr index_t packed_size_v
Definition data_type.hpp:411
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
constexpr auto BlockGemmMXNBSPipeline_Selector()
Definition blockwise_gemm_pipeline_xdlops_mx_moe_nbs_selector.hpp:37
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition gridwise_moe_mx_gemm_bns.hpp:654
const ADataType * p_a_grid
Definition gridwise_moe_mx_gemm_bns.hpp:717
const index_t * p_sorted_token_ids
Definition gridwise_moe_mx_gemm_bns.hpp:714
const index_t * p_sorted_expert_ids
Definition gridwise_moe_mx_gemm_bns.hpp:715
const index_t * p_max_token_id
Definition gridwise_moe_mx_gemm_bns.hpp:716
DsGridPointer p_ds_grid
Definition gridwise_moe_mx_gemm_bns.hpp:721
const CElementwiseOperation c_element_op
Definition gridwise_moe_mx_gemm_bns.hpp:726
CDataType * p_c_grid
Definition gridwise_moe_mx_gemm_bns.hpp:722
const AElementwiseOperation a_element_op
Definition gridwise_moe_mx_gemm_bns.hpp:724
const BScaleDataType * p_b_scale_grid
Definition gridwise_moe_mx_gemm_bns.hpp:720
const BDataType * p_b_grid
Definition gridwise_moe_mx_gemm_bns.hpp:719
const AScaleDataType * p_a_scale_grid
Definition gridwise_moe_mx_gemm_bns.hpp:718
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const AScaleDataType *p_a_scale_grid_, const BDataType *p_b_grid_, const BScaleDataType *p_b_scale_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition gridwise_moe_mx_gemm_bns.hpp:655
const BElementwiseOperation b_element_op
Definition gridwise_moe_mx_gemm_bns.hpp:725
index_t M
Definition gridwise_moe_mx_gemm_bns.hpp:632
index_t TopK
Definition gridwise_moe_mx_gemm_bns.hpp:631
index_t NPadded
Definition gridwise_moe_mx_gemm_bns.hpp:643
index_t MPadded
Definition gridwise_moe_mx_gemm_bns.hpp:642
index_t StrideScaleB
Definition gridwise_moe_mx_gemm_bns.hpp:638
index_t StrideScaleA
Definition gridwise_moe_mx_gemm_bns.hpp:636
index_t MBlock
Definition gridwise_moe_mx_gemm_bns.hpp:648
index_t StrideC
Definition gridwise_moe_mx_gemm_bns.hpp:640
index_t AK0
Definition gridwise_moe_mx_gemm_bns.hpp:646
index_t KPadded
Definition gridwise_moe_mx_gemm_bns.hpp:645
index_t NBlock
Definition gridwise_moe_mx_gemm_bns.hpp:649
__host__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideScaleA_, index_t StrideB_, index_t StrideScaleB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition gridwise_moe_mx_gemm_bns.hpp:583
index_t StrideA
Definition gridwise_moe_mx_gemm_bns.hpp:635
index_t StrideB
Definition gridwise_moe_mx_gemm_bns.hpp:637
index_t KBatch
Definition gridwise_moe_mx_gemm_bns.hpp:641
index_t BK0
Definition gridwise_moe_mx_gemm_bns.hpp:647
index_t KRead
Definition gridwise_moe_mx_gemm_bns.hpp:644
__host__ void Print() const
Definition gridwise_moe_mx_gemm_bns.hpp:618
index_t K
Definition gridwise_moe_mx_gemm_bns.hpp:634
index_t N
Definition gridwise_moe_mx_gemm_bns.hpp:633
index_t NumTokens
Definition gridwise_moe_mx_gemm_bns.hpp:630
std::array< index_t, NumDTensor > StrideDs
Definition gridwise_moe_mx_gemm_bns.hpp:639
index_t a_k_split_offset
Definition gridwise_moe_mx_gemm_bns.hpp:784
index_t b_k_split_offset
Definition gridwise_moe_mx_gemm_bns.hpp:785
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition gridwise_moe_mx_gemm_bns.hpp:731
index_t b_scale_k_split_offset
Definition gridwise_moe_mx_gemm_bns.hpp:787
index_t a_scale_k_split_offset
Definition gridwise_moe_mx_gemm_bns.hpp:786
Definition gridwise_moe_mx_gemm_bns.hpp:179
remove_cvref_t< decltype(BlockGemmMXNBSPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ScaleBlockSize, ADataType, AScaleDataType, BDataType, BScaleDataType, ComputeTypeA, GemmAccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_M3_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_N3_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition gridwise_moe_mx_gemm_bns.hpp:1041
__host__ static __device__ constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition gridwise_moe_mx_gemm_bns.hpp:1304
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const AScaleDataType *p_a_scale_grid, const BDataType *p_b_grid, const BScaleDataType *p_b_scale_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition gridwise_moe_mx_gemm_bns.hpp:1333
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
__host__ static __device__ constexpr index_t At(index_t I)
Definition utility/sequence.hpp:53
Definition tensor_space_filling_curve.hpp:20
Definition static_buffer.hpp:75
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition utility/tuple.hpp:117
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition utility/e8m0.hpp:26
Definition data_type.hpp:42
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1041
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:1087
Definition dtype_vector.hpp:10
#define CK_ENV(name)
Definition utility/env.hpp:129