fused_moegemm_pipeline_flatmm_policy.hpp Source File

fused_moegemm_pipeline_flatmm_policy.hpp Source File

Composable Kernel: fused_moegemm_pipeline_flatmm_policy.hpp Source File
fused_moegemm_pipeline_flatmm_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
11
12namespace ck_tile {
13
15{
17 {
18 // TODO: always 1 dword
19 return 1;
20 }
21
// Returns how many ADataType elements one async copy covers.
// The copy width in bytes is 4 * GetAsyncCopyDwords(); the static_assert below rejects
// any ADataType whose size does not divide that width evenly.
22 template <typename Problem>
23 CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_A()
24 {
25 // using async: the alignment is derived from the async copy width, not a fixed 16 bytes
26 constexpr index_t copy_bytes = 4 * GetAsyncCopyDwords();
27 constexpr index_t data_bytes = sizeof(typename Problem::ADataType);
28 static_assert(copy_bytes % data_bytes == 0);
29 return copy_bytes / data_bytes;
30 }
31
// Returns how many GDataType elements fit in a 16-byte copy (copy_bytes / sizeof(GDataType)).
// The static_assert below rejects any GDataType whose size does not divide 16 evenly.
32 template <typename Problem>
33 CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_G()
34 {
35 constexpr index_t copy_bytes = [&]() { return 16; }(); // immediately-invoked lambda; always yields 16
36 constexpr index_t data_bytes = sizeof(typename Problem::GDataType);
37 static_assert(copy_bytes % data_bytes == 0);
38 return copy_bytes / data_bytes;
39 }
40
// Returns how many DDataType elements fit in a 16-byte copy (copy_bytes / sizeof(DDataType)).
// The static_assert below rejects any DDataType whose size does not divide 16 evenly.
41 template <typename Problem>
42 CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_D()
43 {
44 constexpr index_t copy_bytes = [&]() { return 16; }(); // immediately-invoked lambda; always yields 16
45 constexpr index_t data_bytes = sizeof(typename Problem::DDataType);
46 static_assert(copy_bytes % data_bytes == 0);
47 return copy_bytes / data_bytes;
48 }
49
// Returns the alignment used for O, chosen at compile time from Problem::Traits::OAtomic:
//   OAtomic == 1 -> 2 (ODataType must be 2 bytes wide, enforced by the static_assert)
//   OAtomic == 2 -> 1
//   otherwise    -> 16 / sizeof(ODataType)
// NOTE(review): presumably the unit is ODataType elements, as in the other GetAlignment_*
// helpers - confirm against the callers.
// NOTE: the return type is deduced per instantiation, so the first two branches yield int
// while the last yields the type of a sizeof expression (std::size_t). Unlike
// GetAlignment_A/G/D, the last branch has no divisibility static_assert.
50 template <typename Problem>
51 CK_TILE_HOST_DEVICE static constexpr auto GetAlignment_O()
52 {
53 if constexpr(Problem::Traits::OAtomic == 1)
54 {
55 // pack fp16/bf16 atomic
56 static_assert(sizeof(typename Problem::ODataType) == 2);
57 return 2;
58 }
59 else if constexpr(Problem::Traits::OAtomic == 2)
60 {
61 // fp32 atomic
62 return 1;
63 }
64 else
65 {
66 return 16 / sizeof(typename Problem::ODataType);
67 }
68 }
69
// Returns 16 / sizeof(DataType_) with cv-qualifiers and references stripped from DataType_,
// i.e. how many such elements fit in 16 bytes (integer division).
// The result has the type of a sizeof expression (std::size_t), not index_t.
70 template <typename DataType_>
71 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack()
72 {
73 // TODO: this is for 3d layout
74 return 16 / sizeof(remove_cvref_t<DataType_>);
75 }
76
77 template <typename Problem>
82
83 // used for bridge LDS shuffle
// Returns 16 / sizeof(YDataType), i.e. how many YDataType elements fit in 16 bytes
// (integer division). The result has the type of a sizeof expression (std::size_t).
84 template <typename Problem>
85 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPack_Y()
86 {
87 // TODO: this should match mfma layout
88 return 16 / sizeof(typename Problem::YDataType);
89 }
90
91 template <typename Problem>
93 {
94 constexpr auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
95 constexpr auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
96 static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
97 return a_sld_desc.get_element_space_size();
98 }
99
100 template <typename Problem>
102 {
103 constexpr auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
104 constexpr auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
105 static_assert(bridge_sld_desc.get_element_space_size() ==
106 bridge_sst_desc.get_element_space_size());
107 return bridge_sld_desc.get_element_space_size();
108 }
109
110 template <typename Problem>
112 {
113 constexpr index_t a_lds = GetSmemSize_A<Problem>();
114 constexpr index_t bridge_lds = GetSmemSize_Bridge<Problem>();
115 return max(a_lds, bridge_lds);
116 }
117
118 template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
120 {
121 constexpr index_t K_vec = Alignment;
122 constexpr index_t K_rem = KPerBlock / K_vec;
123
124 if constexpr(get_warp_size() < K_rem)
125 {
126 static_assert(K_rem % get_warp_size() == 0);
127 constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
128 constexpr index_t K_wav = K_rem / get_warp_size();
129 static_assert(K_wav <= NumWarps, "not not support thread has repeat along K yet");
130 constexpr index_t M_wav = NumWarps / K_wav;
131 static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
132 constexpr index_t M_rep = MPerBlock / M_wav;
133
141 sequence<0, 2>>{});
142 }
143 else
144 {
145 constexpr index_t K_lan = K_rem;
146 constexpr index_t M_lan = get_warp_size() / K_lan;
147 constexpr index_t M_wav = NumWarps;
148 static_assert(MPerBlock % (M_lan * M_wav) == 0,
149 "this tile size is too small please check");
150 constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
158 sequence<0, 1>>{});
159 }
160 }
161
162 // optimized version for async; NOT the same as the simple MxK distribution (pay attention!!)
163 template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
165 {
166 constexpr index_t K_vec = Alignment;
167 constexpr index_t K_rem = KPerBlock / K_vec;
168
169 if constexpr(get_warp_size() <= K_rem)
170 {
171 static_assert(K_rem % get_warp_size() == 0);
172 constexpr index_t K_lan = get_warp_size(); // lane within same wave is along gemm-k
173 constexpr index_t K_wav = K_rem / get_warp_size();
174 static_assert(K_wav <= NumWarps, "do not support thread has repeat along K yet");
175 constexpr index_t M_wav = NumWarps / K_wav;
176 static_assert(MPerBlock % M_wav == 0, "this tile size is too small please check");
177 constexpr index_t M_rep = MPerBlock / M_wav;
178 // NOTE: no swap, but hard to avoid LDS bank conflict
186 sequence<0, 2>>{});
187 }
188 else
189 {
190 constexpr index_t K_lan = K_rem;
191 constexpr index_t M_lan = get_warp_size() / K_lan;
192 constexpr index_t M_wav = NumWarps;
193 static_assert(MPerBlock % (M_lan * M_wav) == 0,
194 "this tile size is too small please check");
195 constexpr index_t M_rep = MPerBlock / (M_lan * M_wav);
196 // NOTE: swapped for LDS load bank conflict free
200 // Note M_wave(num waves) is the fastest dim, different from simple 2d
201 // distribution
206 sequence<0, 1>>{});
207 }
208 }
209
210 template <index_t WarpPerBlock_N_,
211 index_t WarpPerBlock_K_,
212 index_t Repeat_N_,
213 index_t Repeat_K_,
214 index_t WarpSize_,
215 index_t Alignment_>
228
229 template <typename Problem>
231 {
232 constexpr index_t Block_M_ = Problem::BlockShape::Block_M0;
233 constexpr index_t Block_K_ = Problem::BlockShape::Block_K0;
234 constexpr index_t NumWarps_ = Problem::BlockShape::NumWarps;
235 constexpr index_t Alignment_ = GetAlignment_A<Problem>();
237 Block_K_,
238 NumWarps_,
239 Alignment_>();
240 }
241
242 template <typename Problem>
244 {
245 constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
246 // constexpr index_t hidden_radio_0 = Problem::Traits::IsGateOnly ? 1 : 2;
247 using S_ = typename Problem::BlockShape;
248 if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
249 {
250 // number<S_::WarpPerBlock_N0>{}.rrr();
251 // number<S_::Repeat_N0>{}.eee();
252 return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N0,
253 S_::WarpPerBlock_K0,
254 S_::Repeat_N0,
255 S_::Repeat_K0,
258 }
259 }
260
261 template <typename Problem>
263 {
264 constexpr auto PermuteEnum = Problem::Traits::PermuteEnum;
265 using S_ = typename Problem::BlockShape;
266 if constexpr(PermuteEnum == FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten)
267 {
268 return MakeGlobalTileDistribution_Nr_Kr_W<S_::WarpPerBlock_N1,
269 S_::WarpPerBlock_K1,
270 S_::Repeat_N1,
271 S_::Repeat_K1,
274 }
275 }
276
277 template <typename Problem>
279 {
282 // using CDataType = typename WarpGemm::CDataType;
283
284 constexpr auto c_block_outer_dstr_encoding =
292
293 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
294 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
295 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
296 return c_block_dstr;
297 }
298
299 template <typename Problem>
301 {
302 // A async->LDS
303 constexpr index_t Block_M = Problem::BlockShape::Block_M0;
304 constexpr index_t Block_K = Problem::BlockShape::Block_K0;
305 // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
306 constexpr index_t WarpSize = ck_tile::get_warp_size();
307 constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
308
309 constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
310 constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
311 constexpr index_t KPad = KPack; // pad between warps
312
313 static_assert(Block_K % KVector == 0);
314 constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
315 if constexpr(LanesPerK >= WarpSize)
316 {
317 // need multiple waves to load K
318 static_assert(LanesPerK % WarpSize == 0);
319 constexpr index_t wavesPerK = LanesPerK / WarpSize;
320 if constexpr(wavesPerK > NumWarps)
321 {
322 // TODO: need multiple issues along K to load all data
323 }
324 else
325 {
326 constexpr index_t wavesPerM = NumWarps / wavesPerK;
327 constexpr index_t NumIssues = Block_M / wavesPerM;
328 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
330 number<wavesPerM>{}, // m1
331 number<wavesPerK>{}, // k0
332 number<WarpSize>{}, // k1
333 number<KVector>{}), // k2
334 make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
335 number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
337 number<KVector>{}, // k1
338 number<1>{}), // k2
339 number<KVector>{}, // lds store vector(actually no explicit store)
340 number<1>{});
341
342 constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
343 lds_block_desc_0,
350
351 return lds_block_desc_issues_warps_lanes;
352 }
353 }
354 else
355 {
356 // lanes within a wave load different M but same K
357 static_assert(WarpSize % LanesPerK == 0);
358 constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
359 constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
360
361 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
363 number<LaneGroups>{}, // m1
364 number<NumWarps>{}, // m2
365 number<LanesPerK>{}, // k0
366 number<KVector>{}), // k1
367 make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
368 number<Block_K>{}, // m1
370 number<KVector>{}, // k0
371 number<1>{}), // k1
372 number<KVector>{}, // lds store vector(actually no explicit store)
373 number<1>{});
374
375 constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
376 lds_block_desc_0,
383
384 return lds_block_desc_issues_warps_lanes;
385 }
386 }
387
388 template <typename Problem>
390 {
391 // A async->LDS
392 // Note that, this descriptor is only to construct the layout inside LDS
393 // in real Gemm pipeline, ds_read may not follow this pattern
394 // (may follow that in tile_distribution)
395 // below code is almost the same as SmemStore dist, with difference:
396 // 1). modify the GuaranteedLastDimensionVectorLength of naive tensor desc
397 // 2). returned descriptor is in MxK 2d layout (lds_desc_m_k)
398 constexpr index_t Block_M = Problem::BlockShape::Block_M0;
399 constexpr index_t Block_K = Problem::BlockShape::Block_K0;
400 // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
401 constexpr index_t WarpSize = ck_tile::get_warp_size();
402 constexpr index_t NumWarps = Problem::BlockShape::NumWarps;
403
404 constexpr index_t KPack = GetSmemKPack_A<Problem>(); // LDS
405 constexpr index_t KVector = GetAlignment_A<Problem>(); // async copy 1 dword
406 constexpr index_t KPad = KPack; // pad between warps
407
408 static_assert(Block_K % KVector == 0);
409 constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
410 if constexpr(LanesPerK >= WarpSize)
411 {
412 // need multiple waves to load K
413 static_assert(LanesPerK % WarpSize == 0);
414 constexpr index_t wavesPerK = LanesPerK / WarpSize;
415 if constexpr(wavesPerK >= NumWarps)
416 {
417 // TODO: need multiple issues along K to load all data
418 }
419 else
420 {
421 constexpr index_t wavesPerM = NumWarps / wavesPerK;
422 constexpr index_t NumIssues = Block_M / wavesPerM;
423 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
425 number<wavesPerM>{}, // m1
426 number<wavesPerK>{}, // k0
427 number<WarpSize>{}, // k1
428 number<KVector>{}), // k2
429 make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
430 number<wavesPerK*(WarpSize * KVector + KPad)>{}, // m1
432 number<KVector>{}, // k1
433 number<1>{}), // k2
434 number<KPack>{}, // lds load vector
435 number<1>{});
436
437 constexpr auto lds_desc_m_k = transform_tensor_descriptor(
438 lds_block_desc_0,
445
446 return lds_desc_m_k;
447 }
448 }
449 else
450 {
451 // lanes within a wave load different M but same K
452 static_assert(WarpSize % LanesPerK == 0);
453 constexpr index_t LaneGroups = WarpSize / LanesPerK; // along m
454 constexpr index_t NumIssues = Block_M / (LaneGroups * NumWarps);
455
456 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
458 number<LaneGroups>{}, // m1
459 number<NumWarps>{}, // m2
460 number<LanesPerK>{}, // k0
461 number<KVector>{}), // k1
462 make_tuple(number<NumWarps*(WarpSize * KVector + KPad)>{}, // m0
463 number<Block_K>{}, // m1
465 number<KVector>{}, // k0
466 number<1>{}), // k1
467 number<KPack>{}, // lds load vector
468 number<1>{});
469
470 constexpr auto lds_desc_m_k = transform_tensor_descriptor(
471 lds_block_desc_0,
478
479 return lds_desc_m_k;
480 }
481 }
482
483 template <typename Problem>
485 {
486 constexpr index_t Block_M = Problem::BlockShape::Block_M0;
487 constexpr index_t Block_N = Problem::BlockShape::Block_N0;
488
489 constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
490 constexpr index_t KPad = 0; // pad between warps
491
492 constexpr auto desc =
496 number<1>{});
497 return desc;
498 }
499
500 template <typename Problem>
502 {
503 constexpr index_t Block_M = Problem::BlockShape::Block_M0;
504 constexpr index_t Block_N = Problem::BlockShape::Block_N0;
505
506 constexpr index_t KVector = GetSmemKPack_Y<Problem>(); // async copy 1 dword
507 constexpr index_t KPad = 0; // KVector; // pad between warps
508
509 constexpr auto desc =
513 number<1>{});
514 return desc;
515 }
516
517 template <typename Problem>
519 {
520 constexpr index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
521 constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N0;
522 constexpr index_t Repeat_M = Problem::BlockShape::Repeat_M0;
523
524 constexpr index_t kAMLane = 16;
525 constexpr index_t kABKLane = 4;
526 constexpr index_t kABKPerLane = 4;
527
528 constexpr index_t KPack = kABKPerLane;
529
530 constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
532 number<Repeat_N>{}, // n
534 number<kABKLane>{}, // n
535 number<kAMLane>{}, // m
536 number<KPack>{}), // n
541 number<KPack>{}, // m
542 number<1>{}), // n
543 number<KPack>{}, // lds store vector(actually no explicit store)
544 number<1>{});
545
546 constexpr auto desc = transform_tensor_descriptor(
547 lds_block_desc_0,
552 number<KPack>{}))),
555
556 return desc;
557 }
558
559 template <typename Problem>
560 CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm0()
561 {
562 using S_ = typename Problem::BlockShape;
563 // A is vgpr, B is agpr. But since we transposed, we also need to swap this
564 // TODO: this is ugly
565 constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
566 // TODO: ugly
567 if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
568 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
569 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
570 {
573 2>>{};
574 }
575 else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
576 std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
577 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
578 {
581 2>>{};
582 }
583 }
584
585 template <typename Problem>
586 CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_0()
587 {
588 // this function returns a seq<...> used to identify gld/sld/valu... inside the mfma sequence
589 // the purpose is to hide those instructions under mfma
590 // every value inside seq<...> is a mask, indicating a specific operation
591 using S_ = typename Problem::BlockShape;
595 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
596 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
597 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
598 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
599 S_::Block_N1 == 128)
600 {
601 // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
602 // gld_a 8x ds_read_b128 sld_a total 64 slot :)
603 // clang-format off
604 constexpr auto seq_all =
605 // 0 1 2 3 4 5 6 7
610 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
611 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
612 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
613 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
614 return seq_all;
615 // clang-format on
616 }
617 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
618 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
619 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
620 S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
621 S_::Block_N1 == 128)
622 {
623 // Total 32 instructions, 16 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
624 // gld_a 8x ds_read_b128 sld_a total 64 slot :)
625 // clang-format off
626 constexpr auto seq_all =
627 // 0 1 2 3 4 5 6 7
631 GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A, GLD_B, SLD_A>{}; // 3
632 return seq_all;
633 // clang-format on
634 }
635 }
636
637 template <typename Problem>
638 CK_TILE_HOST_DEVICE static constexpr auto GetSequencer_1()
639 {
640 // this function returns a seq<...> used to identify gld/sld/valu... inside the mfma sequence
641 // the purpose is to hide those instructions under mfma
642 // every value inside seq<...> is a mask, indicating a specific operation
643 using S_ = typename Problem::BlockShape;
646 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
647 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
648 S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
649 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
650 S_::Block_N1 == 128)
651 {
652 // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
653 // gld_a 8x ds_read_b128 sld_a total 64 slot :)
654 // clang-format off
655 constexpr auto seq_all =
656 // 0 1 2 3 4 5 6 7
659 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
660 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 3
661 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 4
662 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 5
663 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 6
664 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 7
665 return seq_all;
666 // clang-format on
667 }
668 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
669 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
670 S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
671 S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
672 S_::Block_N1 == 128)
673 {
674 // Total 64 instructions, 32 buffer-load-dwordx4 gld_b, 8x buffer-load-dwordx1-async
675 // gld_a 8x ds_read_b128 sld_a total 64 slot :)
676 // clang-format off
677 constexpr auto seq_all =
678 // 0 1 2 3 4 5 6 7
681 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0, // 2
682 GLD_B, 0, GLD_B, 0, GLD_B, 0, GLD_B, 0>{}; // 3
683 return seq_all;
684 // clang-format on
685 }
686 }
687
688 template <typename Problem>
689 CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemm1()
690 {
691 using S_ = typename Problem::BlockShape;
692 constexpr auto wg_ctrl = WGAttrCtlEnum::Raw_avv;
693 // TODO: ugly
694 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
695 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
696 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
697 {
700 2>>{};
701 }
702 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
703 std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
704 S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
705 {
708 2>>{};
709 }
710 }
711
712 template <typename Problem>
714 {
717 using CDataType = typename WarpGemm::CDataType;
718
719 constexpr auto c_block_outer_dstr_encoding =
727
728 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
729 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
730 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
731 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
732 return c_block_tensor;
733 }
734
735 template <typename Problem>
737 {
740 using CDataType = typename WarpGemm::CDataType;
741
742 constexpr auto c_block_outer_dstr_encoding =
750
751 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
752 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
753 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
754 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
755 return c_block_tensor;
756 }
757
758 // this is used as A matrix for 2nd gemm
759 template <typename Problem>
761 {
764
765 // TODO: all waves are along different N, but same M
766 constexpr auto y_outer_dstr_enc =
773
774 constexpr auto y_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
775 y_outer_dstr_enc, typename WarpGemm::AWarpDstrEncoding{});
776 constexpr auto y_block_dstr = make_static_tile_distribution(y_block_dstr_encode);
777 return y_block_dstr;
778 }
779
780 template <typename Problem>
781 CK_TILE_HOST_DEVICE static constexpr auto MakeYBlockTile()
782 {
783 constexpr auto y_block_dstr = MakeYTileDistribution<Problem>();
784 auto y_block_tensor =
786 return y_block_tensor;
787 }
788
789 template <typename Problem>
790 CK_TILE_HOST_DEVICE static constexpr auto GetUK_0()
791 {
792 using S_ = typename Problem::BlockShape;
793 if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
794 std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
795 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
796 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
797 {
799 }
800 else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
801 std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
802 S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
803 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
804 {
806 }
807 }
808
809 template <typename Problem>
810 CK_TILE_HOST_DEVICE static constexpr auto GetUK_1()
811 {
812 using S_ = typename Problem::BlockShape;
813 using T_ = typename Problem::Traits;
814 if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
815 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
816 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
817 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
818 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
819 T_::PipeInterleave == false)
820 {
822 // return FlatmmSn_32x128x512_1x4x1_16x16x32_BF16_itl{};
823 }
824 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
825 std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
826 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
827 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
828 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
829 T_::PipeInterleave == false)
830 {
832 // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16_itl{};
833 }
834 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
835 std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
836 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
837 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
838 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
839 T_::PipeInterleave == true)
840 {
841 // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
843 }
844 else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
845 std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
846 std::is_same_v<typename Problem::TopkWeightDataType, float> &&
847 S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
848 S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
849 T_::PipeInterleave == true)
850 {
851 // return FlatmmSn_32x128x512_1x4x1_16x16x32_FP16{};
853 }
854 }
855};
856} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
@ Raw_avv
Definition warp_gemm_attribute_mfma_impl.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
@ GST_O
Definition fused_moegemm_traits.hpp:48
@ GLD_B
Definition fused_moegemm_traits.hpp:45
@ SLD_A
Definition fused_moegemm_traits.hpp:42
@ GLD_A
Definition fused_moegemm_traits.hpp:44
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
@ b_nr_kr_waveflatten
Definition fused_moegemm_traits.hpp:16
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:401
Definition flatmm_32x512x128_1x4x1_16x16x32.hpp:540
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:18
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
Definition flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:265
Definition flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
Definition fused_moegemm_pipeline_flatmm_policy.hpp:15
static CK_TILE_HOST_DEVICE constexpr auto GetWarpGemm1()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:689
static CK_TILE_HOST_DEVICE constexpr auto MakeYTileDistribution()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:760
static CK_TILE_HOST_DEVICE constexpr auto GetUK_1()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:810
static CK_TILE_HOST_DEVICE constexpr auto GetSequencer_1()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:638
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPack()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:71
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_SimpleMxK_Async()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:164
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPack_A()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:78
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_A()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:230
static CK_TILE_HOST_DEVICE constexpr auto GetUK_0()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:790
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_SimpleMxK()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:119
static CK_TILE_HOST_DEVICE constexpr auto GetAlignment_O()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:51
static CK_TILE_HOST_DEVICE constexpr auto MakeBridgeLdsLoadDesc()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:484
static CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm1()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:736
static CK_TILE_HOST_DEVICE constexpr auto GetAlignment_G()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:33
static CK_TILE_HOST_DEVICE constexpr auto GetAlignment_D()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:42
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_O()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:278
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_Nr_Kr_W()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:216
static CK_TILE_HOST_DEVICE constexpr auto MakeBridgeLdsStoreDesc()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:501
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize_A()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:92
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:111
static CK_TILE_HOST_DEVICE constexpr auto MakeLdsLoadDesc_A()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:389
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_D()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:262
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize_Bridge()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:101
static CK_TILE_HOST_DEVICE constexpr auto MakeCBlockTile_Gemm0()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:713
static CK_TILE_HOST_DEVICE constexpr auto MakeBridgeLdsStoreForUKDesc()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:518
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPack_Y()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:85
static CK_TILE_HOST_DEVICE constexpr auto GetAlignment_A()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:23
static CK_TILE_HOST_DEVICE constexpr index_t GetAsyncCopyDwords()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:16
static CK_TILE_HOST_DEVICE constexpr auto MakeGlobalTileDistribution_G()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:243
static CK_TILE_HOST_DEVICE constexpr auto MakeYBlockTile()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:781
static CK_TILE_HOST_DEVICE constexpr auto MakeLdsStoreDesc_A()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:300
static CK_TILE_HOST_DEVICE constexpr auto GetSequencer_0()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:586
static CK_TILE_HOST_DEVICE constexpr auto GetWarpGemm0()
Definition fused_moegemm_pipeline_flatmm_policy.hpp:560
Definition warp_gemm_attribute_mfma_impl.hpp:1820
Definition warp_gemm_attribute_mfma_impl.hpp:577
Definition warp_gemm_impl.hpp:11
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192