warp_gemm_attribute_mfma.hpp Source File

warp_gemm_attribute_mfma.hpp Source File#

Composable Kernel: warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// Number of groups of consecutive elements to fill in an ABKLane
13{
14 Single = 1,
15 Double = 2,
16 Quad = 4,
18};
19
20template <typename WarpGemmAttributeMfmaImpl_,
23{
25 static constexpr auto AttrNumAccess = AttrNumAccess_;
26 static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
27
28 using ADataType = typename Impl::ADataType;
29 using BDataType = typename Impl::BDataType;
30 using CDataType = typename Impl::CDataType;
31
32 using AVecType = typename Impl::AVecType;
33 using BVecType = typename Impl::BVecType;
34 using CVecType = typename Impl::CVecType;
35
36 static constexpr index_t kM = Impl::kM;
37 static constexpr index_t kN = Impl::kN;
38 static constexpr index_t kK = Impl::kK;
39 static constexpr index_t kKPerThread = Impl::kABKPerLane;
40 static constexpr index_t kCMLane = Impl::kCMLane;
41
42 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
43
44 static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
45 "Multi-block WarpGemmAttributeMfmaImpl is not supported");
46
47 template <index_t kMNLane>
48 static constexpr auto get_warp_dstr_encoding()
49 {
50 if constexpr(AttrNumAccessV == 1)
51 {
58 sequence<1>>{};
59 }
60 else
61 {
62 static_assert(kKPerThread % AttrNumAccessV == 0,
63 "kKPerThread must be divisible by NumAccess");
67 sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
72 }
73 }
76
85
86 // c_vec += a_vec * b_vec
87 template <bool post_nop_ = false>
89 const AVecType& a_vec,
90 const BVecType& b_vec,
91 bool_constant<post_nop_> = {}) const
92 {
93 Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
94 }
95
96 // c_vec += a_vec * b_vec
97 template <index_t opselA, index_t opselB, bool post_nop_ = false>
99 const AVecType& a_vec,
100 const int32_t& a_scale,
101 const BVecType& b_vec,
102 const int32_t& b_scale,
103 bool_constant<post_nop_> = {}) const
104 {
105 Impl{}.template operator()<opselA, opselB>(
106 c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
107 }
108
109 // c_vec = a_vec * b_vec
110 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
111 {
112 return Impl{}(a_vec, b_vec);
113 }
114
115 // c_vec = a_vec * b_vec
116 template <index_t opselA, index_t opselB>
118 const int32_t& a_scale,
119 const BVecType& b_vec,
120 const int32_t& b_scale) const
121 {
122 auto c_vec = Impl{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
123 }
124};
125
126template <typename WarpGemmAttributeMfmaImpl_,
127 index_t kKIter,
130{
131 static_assert(kKIter > 0, "wrong!");
132
134 static constexpr auto AttrNumAccess = AttrNumAccess_;
135 static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
136
137 using ADataType = typename Impl::ADataType;
138 using BDataType = typename Impl::BDataType;
139 using CDataType = typename Impl::CDataType;
140
141 using AVecType =
143 using BVecType =
145 using CVecType = typename Impl::CVecType;
146
147 static constexpr index_t kM = Impl::kM;
148 static constexpr index_t kN = Impl::kN;
149 static constexpr index_t kK = Impl::kK * kKIter;
150 static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
151 static constexpr index_t kCMLane = Impl::kCMLane;
152
153 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
154
155 static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
156 "Multi-block on both M & N directions is not supported");
157
159 {
160 if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
161 {
162 if constexpr(AttrNumAccessV == 1)
163 {
171 sequence<1>>{};
172 }
173 else
174 {
175 static_assert(kKPerThread % AttrNumAccessV == 0,
176 "kKPerThread must be divisible by NumAccess");
181 Impl::kABKLane,
182 Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
187 }
188 }
189 else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
190 {
191 static_assert(AttrNumAccessV == 1,
192 "Multiple access is not supported when using multi-block");
193 // each M blocks share the same data
201 sequence<1>>{};
202 }
203 else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
204 {
205 static_assert(AttrNumAccessV == 1,
206 "Multiple access is not supported when using multi-block");
207 // single block to multi-block thread mapping
215 sequence<1>>{};
216 }
217 }
218
220 {
221 if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
222 {
223 if constexpr(AttrNumAccessV == 1)
224 {
232 sequence<1>>{};
233 }
234 else
235 {
236
237 static_assert(kKPerThread % AttrNumAccessV == 0,
238 "kKPerThread must be divisible by NumAccess");
243 Impl::kABKLane,
244 Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
249 }
250 }
251 else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
252 {
253 static_assert(AttrNumAccessV == 1,
254 "Multiple access is not supported when using multi-block");
255 // single block to multi-block thread mapping
263 sequence<1>>{};
264 }
265 else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
266 {
267 static_assert(AttrNumAccessV == 1,
268 "Multiple access is not supported when using multi-block");
269 // each N blocks share the same data
277 sequence<1>>{};
278 }
279 }
280
282 {
283 if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
284 {
293 }
294 else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
295 {
304 }
305 else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
306 {
309 tuple<
316 }
317 }
318
320
322
324
325 // c_vec += a_vec * b_vec
326 template <bool post_nop_ = false>
328 const AVecType& a_vec,
329 const BVecType& b_vec,
330 bool_constant<post_nop_> = {}) const
331 {
334
335 static_for<0, kKIter, 1>{}([&](auto iKIter) {
336 Impl{}(c_vec,
337 reinterpret_cast<const buf_a&>(a_vec)
338 .template get_as<typename Impl::AVecType>()[iKIter],
339 reinterpret_cast<const buf_b&>(b_vec)
340 .template get_as<typename Impl::BVecType>()[iKIter],
342 });
343 }
344
345 template <index_t iKIter, bool post_nop_ = false>
347 const AVecType& a_vec,
348 const BVecType& b_vec,
350 bool_constant<post_nop_> = {}) const
351 {
354
355 static_assert(iKIter < kKIter);
356
357 // static_for<0, kKIter, 1>{}([&](auto iKIter) {
358 Impl{}(c_vec,
359 reinterpret_cast<const buf_a&>(a_vec)
360 .template get_as<typename Impl::AVecType>()[iKIter],
361 reinterpret_cast<const buf_b&>(b_vec)
362 .template get_as<typename Impl::BVecType>()[iKIter],
364 //});
365 }
366
367 // c_vec = a_vec * b_vec
368 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
369 {
370 constexpr auto I0 = number<0>{};
373
374 // c = a * b
375 auto c_vec = Impl{}(
376 reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
377 reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
378
379 // c += a * b
380 static_for<1, kKIter, 1>{}([&](auto iKIter) {
381 Impl{}(c_vec,
382 reinterpret_cast<const buf_a&>(a_vec)
383 .template get_as<typename Impl::AVecType>()[iKIter],
384 reinterpret_cast<const buf_b&>(b_vec)
385 .template get_as<typename Impl::BVecType>()[iKIter]);
386 });
387
388 return c_vec;
389 }
390};
391
392template <typename WarpGemmAttributeMfmaImpl_,
395{
397 static constexpr auto AttrNumAccess = AttrNumAccess_;
398 static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
399
400 using ADataType = typename Impl::BDataType;
401 using BDataType = typename Impl::ADataType;
402 using CDataType = typename Impl::CDataType;
403
404 using AVecType = typename Impl::BVecType;
405 using BVecType = typename Impl::AVecType;
406 using CVecType = typename Impl::CVecType;
407
408 static constexpr index_t kM = Impl::kN;
409 static constexpr index_t kN = Impl::kM;
410 static constexpr index_t kK = Impl::kK;
411 static constexpr index_t kKPerThread = Impl::kABKPerLane;
412 static constexpr index_t kCMLane = Impl::kCMLane;
413
414 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
415
416 static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
417 "Multi-block WarpGemmAttributeMfmaImpl is not supported");
418
419 template <index_t kMNLane>
420 static constexpr auto get_warp_dstr_encoding()
421 {
422 if constexpr(AttrNumAccessV == 1)
423 {
430 sequence<1>>{};
431 }
432 else
433 {
434 static_assert(kKPerThread % AttrNumAccessV == 0,
435 "kKPerThread must be divisible by NumAccess");
439 sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
444 }
445 }
448
457
458 // c_vec += a_vec * b_vec
459 template <bool post_nop_ = false>
461 const AVecType& a_vec,
462 const BVecType& b_vec,
463 bool_constant<post_nop_> = {}) const
464 {
465 // swap A and B
466 Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
467 }
468
469 // c_vec = a_vec * b_vec
470 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
471 {
472 // swap A and B
473 return Impl{}(b_vec, a_vec);
474 }
475};
476
477template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
479{
481
482 using ADataType = typename Impl::BDataType;
483 using BDataType = typename Impl::ADataType;
484 using CDataType = typename Impl::CDataType;
485
486 using AVecType = typename Impl::BVecType;
487 using BVecType = typename Impl::AVecType;
488 using CVecType = typename Impl::CVecType;
489
490 static constexpr index_t kM = Impl::kN;
491 static constexpr index_t kN = Impl::kM;
492 static constexpr index_t kK = Impl::kK;
493 static constexpr index_t kKPerThread = Impl::kABKPerLane;
494 static constexpr index_t SFactor = SFactor_; // group how many CM1 together
495
496 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
497
498 static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
499 "Multi-block WarpGemmAttributeMfmaImpl is not supported");
500
508#if 0
511 tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
512 Impl::kABKLane,
513 2,
514 Impl::kABKPerLane>,
520
524 sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
529#else
530 // TODO: more test not only 32x32
533 tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
534 Impl::kCMLane,
535 SFactor,
536 Impl::kCM1PerLane>,
542
546 sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
551#endif
552 template <bool post_nop_ = false>
553 // c_vec += a_vec * b_vec
555 const AVecType& a_vec,
556 const BVecType& b_vec,
557 bool_constant<post_nop_> = {}) const
558 {
559 // swap A and B
560 Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
561 }
562
563 // c_vec = a_vec * b_vec
564 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
565 {
566 // swap A and B
567 return Impl{}(b_vec, a_vec);
568 }
569};
570
571template <typename WarpGemmAttributeMfmaImpl_,
572 index_t kKIter,
575{
577 static constexpr auto AttrNumAccess = AttrNumAccess_;
578
579 // swap A and B
580 using ADataType = typename Impl::BDataType;
581 using BDataType = typename Impl::ADataType;
582 using CDataType = typename Impl::CDataType;
583
584 using AVecType =
586 using BVecType =
588 using CVecType = typename Impl::CVecType;
589
590 static constexpr index_t kM = Impl::kN;
591 static constexpr index_t kN = Impl::kM;
592 static constexpr index_t kK = Impl::kK * kKIter;
593 static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
594
595 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
596
597 static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
598 "Multi-block on both M & N directions is not supported");
599
605
611
613 {
614 if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
615 {
624 }
625 else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
626 {
635 }
636 else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
637 {
640 tuple<
647 }
648 }
649
651
653
655
656 template <bool post_nop_ = false>
657 // c_vec += a_vec * b_vec
659 const AVecType& a_vec,
660 const BVecType& b_vec,
661 bool_constant<post_nop_> = {}) const
662 {
665 // swap A and B, value and type
666 static_for<0, kKIter, 1>{}([&](auto iKIter) {
667 Impl{}(c_vec,
668 reinterpret_cast<const buf_b&>(b_vec)
669 .template get_as<typename Impl::BVecType>()[iKIter],
670 reinterpret_cast<const buf_a&>(a_vec)
671 .template get_as<typename Impl::AVecType>()[iKIter],
673 });
674 }
675
676 template <index_t iKIter, bool post_nop_ = false>
677 // c_vec += a_vec * b_vec
679 const AVecType& a_vec,
680 const BVecType& b_vec,
682 bool_constant<post_nop_> = {}) const
683 {
686
687 static_assert(iKIter < kKIter);
688 // swap A and B, value and type
689 // static_for<0, kKIter, 1>{}([&](auto iKIter) {
690 Impl{}(c_vec,
691 reinterpret_cast<const buf_b&>(b_vec)
692 .template get_as<typename Impl::BVecType>()[iKIter],
693 reinterpret_cast<const buf_a&>(a_vec)
694 .template get_as<typename Impl::AVecType>()[iKIter],
696 //});
697 }
698
699 // c_vec = a_vec * b_vec
700 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
701 {
702 constexpr auto I0 = number<0>{};
705
706 // swap A and B, value and type
707 auto c_vec = Impl{}(
708 reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
709 reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
710
711 static_for<1, kKIter, 1>{}([&](auto iKIter) {
712 Impl{}(c_vec,
713 reinterpret_cast<const buf_b&>(b_vec)
714 .template get_as<typename Impl::BVecType>()[iKIter],
715 reinterpret_cast<const buf_a&>(a_vec)
716 .template get_as<typename Impl::AVecType>()[iKIter]);
717 });
718
719 return c_vec;
720 }
721};
722
723template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
725{
727
728 // swap A and B
729 using ADataType = typename Impl::BDataType;
730 using BDataType = typename Impl::ADataType;
731 using CDataType = typename Impl::CDataType;
732
733 using AVecType =
735 using BVecType =
737 using CVecType = typename Impl::CVecType;
738
739 static constexpr index_t kM = Impl::kN;
740 static constexpr index_t kN = Impl::kM;
741 static constexpr index_t kK = Impl::kK * kKIter;
742 static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
743 static constexpr index_t SFactor = SFactor_; // group how many CM1 together
744
745 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
746
747 static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
748 "Multi-block WarpGemmAttributeMfmaImpl is not supported");
749
757#if 0
760 tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
761 Impl::kABKLane,
762 2,
763 Impl::kABKPerLane>,
769
773 sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
778#else
779 // TODO: more test not only 32x32
782 tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
783 Impl::kCMLane,
784 SFactor,
785 Impl::kCM1PerLane>,
791
795 sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
800#endif
801 // c_vec += a_vec * b_vec
802 template <bool post_nop_ = false>
804 const AVecType& a_vec,
805 const BVecType& b_vec,
806 bool_constant<post_nop_> = {}) const
807 {
810 // swap A and B, value and type
811 static_for<0, kKIter, 1>{}([&](auto iKIter) {
812 Impl{}(c_vec,
813 reinterpret_cast<const buf_b&>(b_vec)
814 .template get_as<typename Impl::BVecType>()[iKIter],
815 reinterpret_cast<const buf_a&>(a_vec)
816 .template get_as<typename Impl::AVecType>()[iKIter],
818 });
819 }
820
821 template <index_t iKIter, bool post_nop_ = false>
823 const AVecType& a_vec,
824 const BVecType& b_vec,
826 bool_constant<post_nop_> = {}) const
827 {
830
831 static_assert(iKIter < kKIter);
832 // swap A and B, value and type
833 // static_for<0, kKIter, 1>{}([&](auto iKIter) {
834 Impl{}(c_vec,
835 reinterpret_cast<const buf_b&>(b_vec)
836 .template get_as<typename Impl::BVecType>()[iKIter],
837 reinterpret_cast<const buf_a&>(a_vec)
838 .template get_as<typename Impl::AVecType>()[iKIter],
840 //});
841 }
842
843 // c_vec = a_vec * b_vec
844 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
845 {
848 constexpr auto I0 = number<0>{};
849
850 // swap A and B, value and type
851 auto c_vec = Impl{}(
852 reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
853 reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
854
855 static_for<1, kKIter, 1>{}([&](auto iKIter) {
856 Impl{}(c_vec,
857 reinterpret_cast<const buf_b&>(b_vec)
858 .template get_as<typename Impl::BVecType>()[iKIter],
859 reinterpret_cast<const buf_a&>(a_vec)
860 .template get_as<typename Impl::AVecType>()[iKIter]);
861 });
862
863 return c_vec;
864 }
865};
866
867template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
869{
871
872 using ADataType = typename Impl::ADataType;
873 using BDataType = typename Impl::BDataType;
874 using CDataType = typename Impl::CDataType;
875
876 using AVecType =
878 using BVecType =
880 using CVecType = typename Impl::CVecType;
881
882 static constexpr index_t kM = Impl::kM;
883 static constexpr index_t kN = Impl::kN;
884 static constexpr index_t kK = Impl::kK * kKIter;
885 static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
886 static constexpr index_t SFactor = SFactor_; // group how many CM1 together
887
888 CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
889
890 static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
891 "Multi-block WarpGemmAttributeMfmaImpl is not supported");
892
895 tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
896 Impl::kCMLane,
897 SFactor,
898 Impl::kCM1PerLane>,
904
912
915 tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
921
922 // c_vec += a_vec * b_vec
923 template <bool post_nop_ = false>
925 const AVecType& a_vec,
926 const BVecType& b_vec,
927 bool_constant<post_nop_> = {}) const
928 {
931
932 static_for<0, kKIter, 1>{}([&](auto iKIter) {
933 Impl{}(c_vec,
934 reinterpret_cast<const buf_a&>(a_vec)
935 .template get_as<typename Impl::AVecType>()[iKIter],
936 reinterpret_cast<const buf_b&>(b_vec)
937 .template get_as<typename Impl::BVecType>()[iKIter],
939 });
940 }
941
942 template <index_t iKIter, bool post_nop_ = false>
944 const AVecType& a_vec,
945 const BVecType& b_vec,
947 bool_constant<post_nop_> = {}) const
948 {
951
952 static_assert(iKIter < kKIter);
953
954 // static_for<0, kKIter, 1>{}([&](auto iKIter) {
955 Impl{}(c_vec,
956 reinterpret_cast<const buf_a&>(a_vec)
957 .template get_as<typename Impl::AVecType>()[iKIter],
958 reinterpret_cast<const buf_b&>(b_vec)
959 .template get_as<typename Impl::BVecType>()[iKIter],
961 //});
962 }
963
964 // c_vec = a_vec * b_vec
965 CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
966 {
967 constexpr auto I0 = number<0>{};
970
971 auto c_vec = Impl{}(
972 reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
973 reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
974
975 static_for<1, kKIter, 1>{}([&](auto iKIter) {
976 Impl{}(c_vec,
977 reinterpret_cast<const buf_a&>(a_vec)
978 .template get_as<typename Impl::AVecType>()[iKIter],
979 reinterpret_cast<const buf_b&>(b_vec)
980 .template get_as<typename Impl::BVecType>()[iKIter]);
981 });
982
983 return c_vec;
984 }
985};
986
987} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
WGAttrNumAccessEnum
Definition warp_gemm_attribute_mfma.hpp:13
@ Invalid
Definition warp_gemm_attribute_mfma.hpp:17
@ Single
Definition warp_gemm_attribute_mfma.hpp:14
@ Double
Definition warp_gemm_attribute_mfma.hpp:15
@ Quad
Definition warp_gemm_attribute_mfma.hpp:16
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
int32_t int32_t
Definition integer.hpp:10
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
int32_t index_t
Definition integer.hpp:9
Definition warp_gemm_attribute_mfma.hpp:23
static constexpr auto get_warp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:48
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale) const
Definition warp_gemm_attribute_mfma.hpp:117
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:38
typename Impl::BDataType BDataType
Definition warp_gemm_attribute_mfma.hpp:29
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma.hpp:40
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:42
typename Impl::AVecType AVecType
Definition warp_gemm_attribute_mfma.hpp:32
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:34
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:74
static constexpr auto AttrNumAccess
Definition warp_gemm_attribute_mfma.hpp:25
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:30
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:110
static constexpr auto AttrNumAccessV
Definition warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:36
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:24
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:88
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:39
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:37
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane >, sequence< Impl::kCNLane > >, tuple< sequence< 1, 2 > >, tuple< sequence< 1, 0 > >, sequence< 1, 1 >, sequence< 0, 2 > > CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:77
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:98
typename Impl::ADataType ADataType
Definition warp_gemm_attribute_mfma.hpp:28
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:75
typename Impl::BVecType BVecType
Definition warp_gemm_attribute_mfma.hpp:33
Definition warp_gemm_attribute_mfma.hpp:869
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition warp_gemm_attribute_mfma.hpp:876
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kAMLane/(Impl::kCMLane *SFactor *Impl::kCM1PerLane), Impl::kCMLane, SFactor, Impl::kCM1PerLane >, sequence< Impl::kABKLane, Impl::kABKPerLane *kKIter > >, tuple< sequence< 2, 1, 1, 1, 1 > >, tuple< sequence< 0, 0, 2, 1, 3 > >, sequence< 2 >, sequence< 1 > > AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:893
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:870
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:943
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:883
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:882
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:884
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:965
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition warp_gemm_attribute_mfma.hpp:878
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCM0PerLane/SFactor, Impl::kCMLane, Impl::kCM1PerLane *SFactor >, sequence< Impl::kCNLane > >, tuple< sequence< 1, 2 > >, tuple< sequence< 1, 0 > >, sequence< 1, 1 >, sequence< 0, 2 > > CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:913
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:885
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:924
typename Impl::BDataType BDataType
Definition warp_gemm_attribute_mfma.hpp:873
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:888
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:880
static constexpr index_t SFactor
Definition warp_gemm_attribute_mfma.hpp:886
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kBNLane >, sequence< Impl::kABKLane, Impl::kABKPerLane *kKIter > >, tuple< sequence< 2, 1 > >, tuple< sequence< 0, 0 > >, sequence< 2 >, sequence< 1 > > BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:905
typename Impl::ADataType ADataType
Definition warp_gemm_attribute_mfma.hpp:872
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:874
typename Impl::BDataType ADataType
Definition warp_gemm_attribute_mfma.hpp:729
typename Impl::ADataType BDataType
Definition warp_gemm_attribute_mfma.hpp:730
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:742
static constexpr index_t SFactor
Definition warp_gemm_attribute_mfma.hpp:743
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kAMLane/(Impl::kCMLane *SFactor *Impl::kCM1PerLane), Impl::kCMLane, SFactor, Impl::kCM1PerLane >, sequence< Impl::kABKLane, Impl::kABKPerLane *kKIter > >, tuple< sequence< 2, 1, 1, 1, 1 > >, tuple< sequence< 0, 0, 2, 1, 3 > >, sequence< 2 >, sequence< 1 > > BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:780
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:741
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:803
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:737
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition warp_gemm_attribute_mfma.hpp:733
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:726
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:731
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:740
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:844
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCNLane >, sequence< Impl::kCM0PerLane/SFactor, Impl::kCMLane, Impl::kCM1PerLane *SFactor > >, tuple< sequence< 2, 1 > >, tuple< sequence< 1, 0 > >, sequence< 2, 2 >, sequence< 0, 2 > > CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:792
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:822
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:739
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kBNLane >, sequence< Impl::kABKLane, Impl::kABKPerLane *kKIter > >, tuple< sequence< 2, 1 > >, tuple< sequence< 0, 0 > >, sequence< 2 >, sequence< 1 > > AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:750
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:745
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition warp_gemm_attribute_mfma.hpp:735
Definition warp_gemm_attribute_mfma.hpp:575
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:588
static CK_TILE_DEVICE constexpr auto get_bwarp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:606
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:595
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:593
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:582
static CK_TILE_DEVICE constexpr auto get_awarp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:600
typename Impl::BDataType ADataType
Definition warp_gemm_attribute_mfma.hpp:580
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:590
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:592
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:678
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:591
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:650
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:700
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition warp_gemm_attribute_mfma.hpp:586
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition warp_gemm_attribute_mfma.hpp:584
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:652
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:654
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:576
static constexpr auto AttrNumAccess
Definition warp_gemm_attribute_mfma.hpp:577
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:658
static CK_TILE_DEVICE constexpr auto get_cwarp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:612
typename Impl::ADataType BDataType
Definition warp_gemm_attribute_mfma.hpp:581
Definition warp_gemm_attribute_mfma.hpp:130
static constexpr auto AttrNumAccess
Definition warp_gemm_attribute_mfma.hpp:134
static CK_TILE_DEVICE constexpr auto get_awarp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:158
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma.hpp:151
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:321
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:145
typename Impl::BDataType BDataType
Definition warp_gemm_attribute_mfma.hpp:138
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:346
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:153
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:323
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:368
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:139
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:327
typename Impl::ADataType ADataType
Definition warp_gemm_attribute_mfma.hpp:137
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:147
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition warp_gemm_attribute_mfma.hpp:141
static CK_TILE_DEVICE constexpr auto get_bwarp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:219
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:149
static CK_TILE_DEVICE constexpr auto get_cwarp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:281
static constexpr auto AttrNumAccessV
Definition warp_gemm_attribute_mfma.hpp:135
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:319
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:148
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:133
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:150
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition warp_gemm_attribute_mfma.hpp:143
Definition warp_gemm_attribute_mfma.hpp:479
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCNLane >, sequence< Impl::kCM0PerLane/SFactor, Impl::kCMLane, Impl::kCM1PerLane *SFactor > >, tuple< sequence< 2, 1 > >, tuple< sequence< 1, 0 > >, sequence< 2, 2 >, sequence< 0, 2 > > CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:543
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:491
static constexpr index_t SFactor
Definition warp_gemm_attribute_mfma.hpp:494
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kAMLane/(Impl::kCMLane *SFactor *Impl::kCM1PerLane), Impl::kCMLane, SFactor, Impl::kCM1PerLane >, sequence< Impl::kABKLane, Impl::kABKPerLane > >, tuple< sequence< 2, 1, 1, 1, 1 > >, tuple< sequence< 0, 0, 2, 1, 3 > >, sequence< 2 >, sequence< 1 > > BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:531
typename Impl::BVecType AVecType
Definition warp_gemm_attribute_mfma.hpp:486
typename Impl::AVecType BVecType
Definition warp_gemm_attribute_mfma.hpp:487
typename Impl::ADataType BDataType
Definition warp_gemm_attribute_mfma.hpp:483
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:564
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:480
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:554
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:493
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:484
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kBNLane >, sequence< Impl::kABKLane, Impl::kABKPerLane > >, tuple< sequence< 2, 1 > >, tuple< sequence< 0, 0 > >, sequence< 2 >, sequence< 1 > > AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:501
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:490
typename Impl::BDataType ADataType
Definition warp_gemm_attribute_mfma.hpp:482
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:492
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:496
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:488
Definition warp_gemm_attribute_mfma.hpp:395
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition warp_gemm_attribute_mfma.hpp:470
static constexpr index_t kCMLane
Definition warp_gemm_attribute_mfma.hpp:412
static CK_TILE_HOST_DEVICE constexpr auto get_num_of_access()
Definition warp_gemm_attribute_mfma.hpp:414
typename Impl::BDataType ADataType
Definition warp_gemm_attribute_mfma.hpp:400
typename Impl::AVecType BVecType
Definition warp_gemm_attribute_mfma.hpp:405
static constexpr index_t kKPerThread
Definition warp_gemm_attribute_mfma.hpp:411
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) BWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:447
typename Impl::BVecType AVecType
Definition warp_gemm_attribute_mfma.hpp:404
static constexpr index_t kK
Definition warp_gemm_attribute_mfma.hpp:410
tile_distribution_encoding< sequence<>, tuple< sequence< Impl::kCNLane >, sequence< Impl::kCM0PerLane, Impl::kCMLane, Impl::kCM1PerLane > >, tuple< sequence< 2, 1 > >, tuple< sequence< 1, 0 > >, sequence< 2, 2 >, sequence< 0, 2 > > CWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:449
typename Impl::ADataType BDataType
Definition warp_gemm_attribute_mfma.hpp:401
static constexpr auto AttrNumAccessV
Definition warp_gemm_attribute_mfma.hpp:398
static constexpr auto get_warp_dstr_encoding()
Definition warp_gemm_attribute_mfma.hpp:420
typename Impl::CDataType CDataType
Definition warp_gemm_attribute_mfma.hpp:402
static constexpr auto AttrNumAccess
Definition warp_gemm_attribute_mfma.hpp:397
typename Impl::CVecType CVecType
Definition warp_gemm_attribute_mfma.hpp:406
static constexpr index_t kN
Definition warp_gemm_attribute_mfma.hpp:409
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition warp_gemm_attribute_mfma.hpp:396
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) AWarpDstrEncoding
Definition warp_gemm_attribute_mfma.hpp:446
static constexpr index_t kM
Definition warp_gemm_attribute_mfma.hpp:408
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition warp_gemm_attribute_mfma.hpp:460
Definition tile/core/utility/functional.hpp:43
Definition tile/core/utility/debug.hpp:67
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192