device_batched_gemm_xdl_fpAintB_b_scale.hpp Source File

device_batched_gemm_xdl_fpAintB_b_scale.hpp Source File

Composable Kernel: device_batched_gemm_xdl_fpAintB_b_scale.hpp Source File
device_batched_gemm_xdl_fpAintB_b_scale.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21
22// Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer pipelines
23// in the same kernel function. Blockers:
24// 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses operate
25// on two LDS chunks.
26// 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. AB and C may not use
27// the same LDS buffer when we declare __shared__ inside blkgemmpipe.
// Batched, split-K GEMM kernel with B scaling that uses a single LDS buffer (p_shared).
//
// NOTE(review): this listing omits source lines 33 and 38. TailNum (used in the Run call below)
// is declared on line 33, and the kernel name/parameter list
// `kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)` is on line 38.
28template <typename GridwiseGemm,
29 typename BatchedGemmArg,
30 bool HasMainKBlockLoop,
31 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
32 index_t MinimumOccupancy = 1,
34__global__ void
35#if CK_USE_LAUNCH_BOUNDS
36__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
37#endif
39{
// The body is compiled only for gfx9 / gfx11 / gfx12 targets, and only when the gridwise GEMM
// accepts CGlobalMemoryDataOperation; otherwise the kernel does nothing.
40#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
41 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
42 {
43 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
44
// blockIdx.z enumerates (split-K slice, batch) pairs: z == k_idx * Batch + g_idx.
45 const index_t g_idx = blockIdx.z % karg.Batch;
46 const index_t k_idx = blockIdx.z / karg.Batch;
47
// Element offsets of this batch's A, B, C and B-scale tensors.
48 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
49 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
50 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
51 const auto b_scale_batch_offset =
52 karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx);
53
// Extra offsets that select this block's split-K slice within the batch.
54 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
55
// Run the per-block GEMM on pointers advanced by both the batch and the split-K offsets.
56 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
57 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
58 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
59 karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset,
60 karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
61 p_shared,
62 karg);
63 }
64#else
// Other targets: only mark karg as used.
65 ignore = karg;
66#endif // defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
67}
68
// Batched, split-K GEMM kernel with B scaling that uses two LDS buffers (p_shared_0 / p_shared_1)
// and GridwiseGemm::Run_2Lds. Otherwise identical in structure to the single-LDS kernel.
//
// NOTE(review): this listing omits source lines 74 and 79. TailNum (used in the Run_2Lds call
// below) is declared on line 74, and the kernel name/parameter list
// `kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)` is on line 79.
69template <typename GridwiseGemm,
70 typename BatchedGemmArg,
71 bool HasMainKBlockLoop,
72 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
73 index_t MinimumOccupancy = 1,
75__global__ void
76#if CK_USE_LAUNCH_BOUNDS
77__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
78#endif
80{
// The body is compiled only for gfx9 / gfx11 / gfx12 targets, and only when the gridwise GEMM
// accepts CGlobalMemoryDataOperation; otherwise the kernel does nothing.
81#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
82 if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
83 {
84 // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
85 // operate on different LDS chunks at the same time, without an ordering dependency.
86 __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87 __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
88
// blockIdx.z enumerates (split-K slice, batch) pairs: z == k_idx * Batch + g_idx.
89 const index_t g_idx = blockIdx.z % karg.Batch;
90 const index_t k_idx = blockIdx.z / karg.Batch;
91
// Element offsets of this batch's A, B, C and B-scale tensors.
92 const auto a_batch_offset = karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx);
93 const auto b_batch_offset = karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx);
94 const auto c_batch_offset = karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx);
95 const auto b_scale_batch_offset =
96 karg.compute_ptr_offset_of_batch.GetSacleBPtrOffset(g_idx);
97
// Extra offsets that select this block's split-K slice within the batch.
98 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, k_idx);
99
// Run the per-block GEMM on pointers advanced by both the batch and the split-K offsets.
100 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
101 karg.p_a_grid + a_batch_offset + splitk_batch_offset.a_k_split_offset,
102 karg.p_b_grid + b_batch_offset + splitk_batch_offset.b_k_split_offset,
103 karg.p_c_grid + c_batch_offset + splitk_batch_offset.c_reduce_offset,
104 karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
105 p_shared_0,
106 p_shared_1,
107 karg);
108 }
109#else
// Other targets: only mark karg as used.
110 ignore = karg;
111#endif // defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
112}
113
114namespace tensor_operation {
115namespace device {
116
117template <typename ALayout,
118 typename BLayout,
119 typename CLayout,
120 typename ADataType,
121 typename BDataType,
122 typename BScaleDataType,
123 typename CDataType,
124 typename GemmAccDataType,
125 typename CShuffleDataType,
126 typename AElementwiseOperation,
127 typename BElementwiseOperation,
128 typename CElementwiseOperation,
129 GemmSpecialization GemmSpec,
130 index_t BlockSize,
131 index_t ScaleBlockN, // scale block for N
132 index_t ScaleBlockK, // scale block for K
133 index_t MPerBlock,
134 index_t NPerBlock,
135 index_t KPerBlock,
136 index_t AK1,
137 index_t BK1,
138 index_t MPerXDL,
139 index_t NPerXDL,
140 index_t MXdlPerWave,
141 index_t NXdlPerWave,
142 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
143 typename ABlockTransferThreadClusterArrangeOrder,
144 typename ABlockTransferSrcAccessOrder,
145 index_t ABlockTransferSrcVectorDim,
146 index_t ABlockTransferSrcScalarPerVector,
147 index_t ABlockTransferDstScalarPerVector_AK1,
148 bool ABlockLdsExtraM,
149 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
150 typename BBlockTransferThreadClusterArrangeOrder,
151 typename BBlockTransferSrcAccessOrder,
152 index_t BBlockTransferSrcVectorDim,
153 index_t BBlockTransferSrcScalarPerVector,
154 index_t BBlockTransferDstScalarPerVector_BK1,
155 bool BBlockLdsExtraN,
156 index_t CShuffleMXdlPerWavePerShuffle,
157 index_t CShuffleNXdlPerWavePerShuffle,
158 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
159 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
162 typename ComputeTypeA = CDataType,
163 typename ComputeTypeB = ComputeTypeA,
164 bool PermuteA = false,
165 bool PermuteB = false>
167 : public DeviceBatchedGemmV2BScale<ALayout,
168 BLayout,
169 CLayout,
170 ADataType,
171 BDataType,
172 BScaleDataType,
173 CDataType,
174 ScaleBlockN,
175 ScaleBlockK,
176 AElementwiseOperation,
177 BElementwiseOperation,
178 CElementwiseOperation>
179{
// NXdlPerWave re-derived for wave64 (true) and wave32 (false) targets by GetNXdlPerWave<>().
// NOTE(review): GetNXdlPerWave is presumably generated by the GET_NXDL_PER_WAVE_IMPL macro on
// source line 180, which is absent from this listing — confirm. Run() and IsSupportedArgument()
// only use the variant whose value here is > 0.
181 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
182 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
183
184 // GridwiseGemm
// GridwiseGemmBase<NXdlPerWave_> forwards this device op's template parameters to the gridwise
// GEMM and varies only NXdlPerWave. This lets both a wave64 and a wave32 variant be instantiated.
// NOTE(review): source line 186 (per the Doxygen index, `GridwiseGemmBase =
// GridwiseGemm_xdl_cshuffle_v3<`) is absent from this listing. So are lines 237-238, which define
// GridwiseGemm64 and GridwiseGemm32.
185 template <index_t NXdlPerWave_>
187 ALayout,
188 BLayout,
189 CLayout,
190 ADataType,
191 BDataType,
192 GemmAccDataType,
193 CShuffleDataType,
194 CDataType,
195 AElementwiseOperation,
196 BElementwiseOperation,
197 CElementwiseOperation,
198 GemmSpec,
199 BlockSize,
200 ScaleBlockN,
201 ScaleBlockK,
202 MPerBlock,
203 NPerBlock,
204 KPerBlock,
205 AK1,
206 BK1,
207 MPerXDL,
208 NPerXDL,
209 MXdlPerWave,
210 NXdlPerWave_,
211 ABlockTransferThreadClusterLengths_AK0_M_AK1,
212 ABlockTransferThreadClusterArrangeOrder,
213 ABlockTransferSrcAccessOrder,
214 ABlockTransferSrcVectorDim,
215 ABlockTransferSrcScalarPerVector,
216 ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): the literal `false` arguments on source lines 217 and 225 are positional; their
// meaning is defined by the gridwise GEMM template — confirm there before changing them.
217 false,
218 ABlockLdsExtraM,
219 BBlockTransferThreadClusterLengths_BK0_N_BK1,
220 BBlockTransferThreadClusterArrangeOrder,
221 BBlockTransferSrcAccessOrder,
222 BBlockTransferSrcVectorDim,
223 BBlockTransferSrcScalarPerVector,
224 BBlockTransferDstScalarPerVector_BK1,
225 false,
226 BBlockLdsExtraN,
227 CShuffleMXdlPerWavePerShuffle,
228 CShuffleNXdlPerWavePerShuffle,
229 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
230 CShuffleBlockTransferScalarPerVector_NPerBlock,
231 BlkGemmPipeSched,
232 BlkGemmPipelineVer,
233 ComputeTypeA,
234 ComputeTypeB,
235 PermuteA,
236 PermuteB>;
239
// Number of logical A values stored per ADataType element. It is 2 when the condition on source
// line 241 holds and 1 otherwise; it is used as a divisor when sizing the A buffer in RunImp.
// NOTE(review): line 241 is absent from this listing; presumably it tests for a packed 4-bit
// ADataType (e.g. pk_i4_t) — confirm.
240 static constexpr index_t APackedSize = []() {
242 return 2;
243 else
244 return 1;
245 }();
246
// Number of logical B values stored per BDataType element. It is 2 when the condition on source
// line 248 holds and 1 otherwise. It divides the B batch offset (GetBPtrOffset) and the B buffer
// size in RunImp.
// NOTE(review): line 248 is absent from this listing; presumably it tests for a packed 4-bit
// BDataType (e.g. pk_i4_t) — confirm.
247 static constexpr index_t BPackedSize = []() {
249 return 2;
250 else
251 return 1;
252 }();
253
255 {
// Stores the per-batch element strides of A, B, C and the B-scale tensor. The Get*PtrOffset
// methods turn a batch index into a 64-bit element offset.
// NOTE(review): source lines 254 (`struct ComputePtrOffsetOfStridedBatch`) and 256 (constructor
// name and its first parameter, `index_t BatchStrideA,`) are absent from this listing.
257 index_t BatchStrideB,
258 index_t BatchStrideC,
259 index_t BatchStrideScaleB)
260 : BatchStrideA_(BatchStrideA),
261 BatchStrideB_(BatchStrideB),
262 BatchStrideC_(BatchStrideC),
263 BatchStrideScaleB_(BatchStrideScaleB)
264 {
265 }
266
267 __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
268 {
269 return g_idx * static_cast<long_index_t>(BatchStrideA_);
270 }
271
272 __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
273 {
274 return g_idx * static_cast<long_index_t>(BatchStrideB_) / BPackedSize;
275 }
276
277 __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
278 {
279 return g_idx * static_cast<long_index_t>(BatchStrideC_);
280 }
281 __host__ __device__ constexpr long_index_t GetSacleBPtrOffset(index_t g_idx) const
282 {
283 return g_idx * static_cast<long_index_t>(BatchStrideScaleB_);
284 }
285
286 private:
// Per-batch strides in elements, stored as index_t. The getters above widen them to
// long_index_t before multiplying by the batch index.
287 index_t BatchStrideA_;
288 index_t BatchStrideB_;
289 index_t BatchStrideC_;
290 index_t BatchStrideScaleB_;
291 };
292
// Kernel argument for one GridwiseGemm instantiation. It is the gridwise GEMM's own Argument
// (problem sizes, strides, pointers, KBatch, element-wise ops) extended with the batch count
// and the per-batch strides used by the kernels.
// NOTE(review): source lines 296-297 (member declarations; per the Doxygen index line 297 is
// `ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch`, and presumably 296 declares
// `index_t Batch`) and line 335 (start of the compute_ptr_offset_of_batch initializer) are
// absent from this listing.
293 template <typename GridwiseGemm>
294 struct ArgumentBase : public GridwiseGemm::Argument
295 {
298
299 ArgumentBase(const ADataType* p_a_grid_,
300 const BDataType* p_b_grid_,
301 CDataType* p_c_grid_,
302 index_t M_,
303 index_t N_,
304 index_t K_,
305 index_t StrideA_,
306 index_t StrideB_,
307 index_t StrideC_,
308 index_t StrideScaleB_,
309 index_t BatchStrideA_,
310 index_t BatchStrideB_,
311 index_t BatchStrideC_,
312 index_t BatchStrideScaleB_,
313 const BScaleDataType* p_b_scale_grid_,
314 index_t Batch_,
315 index_t KBatch_,
316 AElementwiseOperation a_element_op_,
317 BElementwiseOperation b_element_op_,
318 CElementwiseOperation c_element_op_)
// Everything except Batch and the batch strides is forwarded to the gridwise Argument.
319 : GridwiseGemm::Argument(p_a_grid_,
320 p_b_grid_,
321 p_c_grid_,
322 M_,
323 N_,
324 K_,
325 StrideA_,
326 StrideB_,
327 StrideC_,
328 StrideScaleB_,
329 p_b_scale_grid_,
330 KBatch_, // KBatch
331 a_element_op_,
332 b_element_op_,
333 c_element_op_),
334 Batch(Batch_),
336 BatchStrideA_, BatchStrideB_, BatchStrideC_, BatchStrideScaleB_)
337 {
338 }
339 };
341
342 // Invoker
343 struct Invoker : public BaseInvoker
344 {
// RunImp: validates `arg`, computes the launch grid, then selects and launches the kernel
// instantiation matching the pipeline version, main-loop presence, tail number and KBatch.
// Returns the measured average time reported by the launch helpers.
// NOTE(review): source line 346 (per the Doxygen index,
// `float RunImp(const ArgumentBase<GridwiseGemm>& arg,`) is absent from this listing.
345 template <typename GridwiseGemm>
347 const StreamConfig& stream_config = StreamConfig{})
348 {
349 using DeviceArgument = ArgumentBase<GridwiseGemm>;
// Dump the argument when logging is enabled.
350 if(stream_config.log_level_ > 0)
351 {
352 arg.Print();
353 }
354
// Refuse configurations that the gridwise GEMM rejects.
355 if(!GridwiseGemm::CheckValidity(arg))
356 {
357 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
358 }
359
// The third grid dimension covers Batch * KBatch; the kernels split blockIdx.z back into
// (batch index, split-K index).
360 index_t gdx, gdy, gdz;
361 std::tie(gdx, gdy, gdz) =
362 GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.Batch * arg.KBatch);
363
364 float ave_time = 0;
365
// K extent handled by one split-K slice: ceil(K / (KBatch * KPerBlock)) * KPerBlock,
// i.e. rounded up to whole KPerBlock tiles. It decides whether the main K-block loop runs
// and which tail number is used below.
366 index_t k_grain = arg.KBatch * KPerBlock;
367 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
368
369 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
370
// Launches `kernel` (one of the instantiations selected below) and stores its timing in
// ave_time. When stream_config.flush_cache is set, caches are flushed between timed runs.
371 const auto Run = [&](const auto& kernel) {
372 if(stream_config.flush_cache)
373 {
// Cache-flushing mode launches on a copy of the argument.
374 DeviceArgument arg_ = arg;
375
376 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
377 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
378 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
379 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
380
// Byte sizes of the A and B buffers. Packed types store APackedSize/BPackedSize logical
// values per element, hence the division.
381 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
382 sizeof(ADataType) / APackedSize;
383 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
384 sizeof(BDataType) / BPackedSize;
385
// NOTE(review): source line 386, which declares `rotating_mem` (the Doxygen index points to
// a rotating-memory wrapper), is absent from this listing. Only the A and B sizes are passed
// to it; presumably it cycles among stream_config.rotating_count copies of A/B — confirm.
387 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
388 rotating_mem.Print();
389
// Preprocessing hook executed before each timed launch.
390 auto run_flush_cache = [&]() {
391 // flush icache
// NOTE(review): source line 392 (the icache flush call) is absent from this listing.
393 // rotating mem
394 rotating_mem.Next();
395 // clear c mem
396 if(arg_.KBatch > 1)
397 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
398 0,
399 arg_.M * arg_.N * sizeof(CDataType),
400 stream_config.stream_id_));
401 };
402
// Timed launch that runs run_flush_cache before each kernel invocation.
// NOTE(review): the opening of this call (source line 403; the Doxygen index points to
// launch_and_time_kernel_with_preprocess) is absent from this listing.
404 stream_config,
405 run_flush_cache,
406 kernel,
407 dim3(gdx, gdy, gdz),
408 dim3(BlockSize),
409 0,
410 arg_);
411 }
412 else
413 {
414 if(arg.KBatch > 1)
415 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
416 0,
417 arg.M * arg.N * sizeof(CDataType),
418 stream_config.stream_id_));
419
420 ave_time = launch_and_time_kernel(
421 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
422 }
423 };
424
425 constexpr index_t minimum_occupancy =
426 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
427 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
428 MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
429 ? 2
430 : 1
431 : 2;
432
433 if(has_main_k_block_loop)
434 {
435 // Tail number always full
436 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
437 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
438 {
439 if(arg.KBatch > 1)
440 {
442 GridwiseGemm,
443 DeviceArgument,
444 true,
446 minimum_occupancy>;
447 Run(kernel);
448 }
449 else
450 {
452 GridwiseGemm,
453 DeviceArgument,
454 true,
456 minimum_occupancy>;
457 Run(kernel);
458 }
459 }
460 // Tail number could be One to Seven
461 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
462 {
463 if(arg.KBatch > 1)
464 {
465 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
466 {
468 GridwiseGemm,
469 DeviceArgument,
470 true,
472 minimum_occupancy,
474 Run(kernel);
475 }
476 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
478 {
480 GridwiseGemm,
481 DeviceArgument,
482 true,
484 minimum_occupancy,
486 Run(kernel);
487 }
488
489 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
490 {
491 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
492 {
494 GridwiseGemm,
495 DeviceArgument,
496 true,
498 minimum_occupancy,
500 Run(kernel);
501 }
502 }
503
504 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
505 {
506 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
508 {
510 GridwiseGemm,
511 DeviceArgument,
512 true,
514 minimum_occupancy,
516 Run(kernel);
517 }
518 }
519
520 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
521 {
522 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
524 {
526 GridwiseGemm,
527 DeviceArgument,
528 true,
530 minimum_occupancy,
532 Run(kernel);
533 }
534 }
535
536 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
537 {
538 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
540 {
542 GridwiseGemm,
543 DeviceArgument,
544 true,
546 minimum_occupancy,
548 Run(kernel);
549 }
550 }
551
552 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
553 {
554 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
555 {
557 GridwiseGemm,
558 DeviceArgument,
559 true,
561 minimum_occupancy,
563 Run(kernel);
564 }
565 }
566
567 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
568 {
569 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
571 {
573 GridwiseGemm,
574 DeviceArgument,
575 true,
577 minimum_occupancy,
579 Run(kernel);
580 }
581 }
582 }
583 else
584 {
585 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
586 {
588 GridwiseGemm,
589 DeviceArgument,
590 true,
592 minimum_occupancy,
594 Run(kernel);
595 }
596 else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
598 {
600 GridwiseGemm,
601 DeviceArgument,
602 true,
604 minimum_occupancy,
606 Run(kernel);
607 }
608
609 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
610 {
611 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
612 {
614 GridwiseGemm,
615 DeviceArgument,
616 true,
618 minimum_occupancy,
620 Run(kernel);
621 }
622 }
623
624 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
625 {
626 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
628 {
630 GridwiseGemm,
631 DeviceArgument,
632 true,
634 minimum_occupancy,
636 Run(kernel);
637 }
638 }
639
640 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
641 {
642 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
644 {
646 GridwiseGemm,
647 DeviceArgument,
648 true,
650 minimum_occupancy,
652 Run(kernel);
653 }
654 }
655
656 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
657 {
658 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
660 {
662 GridwiseGemm,
663 DeviceArgument,
664 true,
666 minimum_occupancy,
668 Run(kernel);
669 }
670 }
671
672 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
673 {
674 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
675 {
677 GridwiseGemm,
678 DeviceArgument,
679 true,
681 minimum_occupancy,
683 Run(kernel);
684 }
685 }
686
687 if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
688 {
689 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
691 {
693 GridwiseGemm,
694 DeviceArgument,
695 true,
697 minimum_occupancy,
699 Run(kernel);
700 }
701 }
702 }
703 }
704 // Tail number could be Odd or Even
705 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
706 {
707 if(arg.KBatch > 1)
708 {
709 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
710 {
712 GridwiseGemm,
713 DeviceArgument,
714 true,
716 minimum_occupancy,
718 Run(kernel);
719 }
720 else
721 {
723 GridwiseGemm,
724 DeviceArgument,
725 true,
727 minimum_occupancy,
729 Run(kernel);
730 }
731 }
732 else
733 {
734 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
735 {
737 GridwiseGemm,
738 DeviceArgument,
739 true,
741 minimum_occupancy,
743 Run(kernel);
744 }
745 else
746 {
748 GridwiseGemm,
749 DeviceArgument,
750 true,
752 minimum_occupancy,
754 Run(kernel);
755 }
756 }
757 }
758 else
759 {
760 if(arg.KBatch > 1)
761 {
762 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
763 {
765 GridwiseGemm,
766 DeviceArgument,
767 true,
769 minimum_occupancy,
771 Run(kernel);
772 }
773 else
774 {
776 GridwiseGemm,
777 DeviceArgument,
778 true,
780 minimum_occupancy,
782 Run(kernel);
783 }
784 }
785 else
786 {
787 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
788 {
790 GridwiseGemm,
791 DeviceArgument,
792 true,
794 minimum_occupancy,
796 Run(kernel);
797 }
798 else
799 {
801 GridwiseGemm,
802 DeviceArgument,
803 true,
805 minimum_occupancy,
807 Run(kernel);
808 }
809 }
810 }
811 }
812 else
813 {
814 // Tail number always 1
815 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
816 {
817 if(arg.KBatch > 1)
818 {
820 GridwiseGemm,
821 DeviceArgument,
822 false,
824 minimum_occupancy>;
825 Run(kernel);
826 }
827 else
828 {
830 GridwiseGemm,
831 DeviceArgument,
832 false,
834 minimum_occupancy>;
835 Run(kernel);
836 }
837 }
838 }
839
840 return ave_time;
841 }
842
843 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
844 {
845 if(get_warp_size() == 64)
846 {
847 if constexpr(NXdlPerWave64 > 0)
848 {
849 return RunImp<GridwiseGemm64>(arg, stream_config);
850 }
851 }
852 else
853 {
854 if constexpr(NXdlPerWave32 > 0)
855 {
856 using Argument32 = ArgumentBase<GridwiseGemm32>;
857 return RunImp<GridwiseGemm32>(reinterpret_cast<const Argument32&>(arg),
858 stream_config);
859 }
860 }
861 return 0;
862 }
863
864 // polymorphic
865 float Run(const BaseArgument* p_arg,
866 const StreamConfig& stream_config = StreamConfig{}) override
867 {
868 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
869 }
870 };
871
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check. Until then every parameter combination is accepted.
    constexpr bool accepted = true;
    return accepted;
}
877
// Host-side check of whether this instance can run `arg` on the current device.
// NOTE(review): source lines 880 and 904 are absent from this listing. Line 880 presumably opens
// an `if(!ck::is_xdl_wmma_supported())` guard (closed by the first `{ return false; }` below).
// Line 904 presumably is `return GridwiseGemm64::CheckValidity(arg);` — confirm. As listed, the
// wave64 path would fall through to `return false`.
878 static bool IsSupportedArgument(const Argument& arg)
879 {
881 {
882 return false;
883 }
// Split-K (KBatch > 1) is rejected on gfx11.
884 if(is_gfx11_supported() && arg.KBatch > 1)
885 {
886 return false;
887 }
// Split-K with bf16 output is rejected where bf16 atomics are unsupported.
888 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
889 {
890 return false;
891 }
892
// K must be a multiple of both AK1 and BK1 unless the specialization pads K.
893 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
894 GemmSpec == GemmSpecialization::NKPadding ||
895 GemmSpec == GemmSpecialization::MNKPadding ||
896 GemmSpec == GemmSpecialization::KPadding))
897 {
898 return false;
899 }
// Final check by the GridwiseGemm variant that matches the runtime wave size. Argument is
// ArgumentBase<GridwiseGemm64>, so it is reinterpreted for the wave32 variant.
900 if(get_warp_size() == 64)
901 {
902 if constexpr(NXdlPerWave64 > 0)
903 {
905 }
906 }
907 else
908 {
909 if constexpr(NXdlPerWave32 > 0)
910 {
911 using Argument32 = ArgumentBase<GridwiseGemm32>;
912 return GridwiseGemm32::CheckValidity(reinterpret_cast<const Argument32&>(arg));
913 }
914 }
915 return false;
916 }
917
918 // polymorphic
919 bool IsSupportedArgument(const BaseArgument* p_arg) override
920 {
921 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
922 }
923
924 index_t GetKPerBlock() override { return KPerBlock; }
925
926 bool GetPermuteB() override { return PermuteB; }
927
928 static auto MakeArgument(const ADataType* p_a,
929 const BDataType* p_b,
930 CDataType* p_c,
931 index_t M,
932 index_t N,
933 index_t K,
934 index_t StrideA,
935 index_t StrideB,
936 index_t StrideC,
937 index_t StrideScaleB,
938 index_t BatchStrideA,
939 index_t BatchStrideB,
940 index_t BatchStrideC,
941 index_t BatchStrideScaleB,
942 const BScaleDataType* p_b_scale,
943 index_t Batch,
944 index_t KBatch,
945 AElementwiseOperation a_element_op,
946 BElementwiseOperation b_element_op,
947 CElementwiseOperation c_element_op)
948 {
949 return Argument{p_a,
950 p_b,
951 p_c,
952 M,
953 N,
954 K,
955 StrideA,
956 StrideB,
957 StrideC,
958 StrideScaleB,
959 BatchStrideA,
960 BatchStrideB,
961 BatchStrideC,
962 BatchStrideScaleB,
963 p_b_scale,
964 Batch,
965 KBatch,
966 a_element_op,
967 b_element_op,
968 c_element_op};
969 }
970
971 static auto MakeInvoker() { return Invoker{}; }
972
973 // polymorphic
974 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
975 const void* p_b,
976 void* p_c,
977 index_t M,
978 index_t N,
979 index_t K,
980 index_t StrideA,
981 index_t StrideB,
982 index_t StrideC,
983 index_t StrideScaleB,
984 index_t BatchStrideA,
985 index_t BatchStrideB,
986 index_t BatchStrideC,
987 index_t BatchStrideScaleB,
988 const void* p_b_scale,
989 index_t Batch,
990 index_t KBatch,
991 AElementwiseOperation a_element_op,
992 BElementwiseOperation b_element_op,
993 CElementwiseOperation c_element_op) override
994 {
995 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
996 static_cast<const BDataType*>(p_b),
997 static_cast<CDataType*>(p_c),
998 M,
999 N,
1000 K,
1001 StrideA,
1002 StrideB,
1003 StrideC,
1004 StrideScaleB,
1005 BatchStrideA,
1006 BatchStrideB,
1007 BatchStrideC,
1008 BatchStrideScaleB,
1009 static_cast<const BScaleDataType*>(p_b_scale),
1010 Batch,
1011 KBatch,
1012 a_element_op,
1013 b_element_op,
1014 c_element_op);
1015 }
1016
1017 // polymorphic
1018 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1019 {
1020 return std::make_unique<Invoker>(Invoker{});
1021 }
1022
1023 // polymorphic
// Human-readable summary of this instance's tuning parameters.
// NOTE(review): the entries of the two lookup maps (source lines 1029-1030 and 1033-1037) are
// absent from this listing.
// NOTE(review): the string starts with "DeviceGemmXdlUniversal" rather than this batched B-scale
// op's own name. The prefetch-stage count is always taken from GridwiseGemm64, even on wave32
// devices. Confirm both are intended.
1024 std::string GetTypeString() const override
1025 {
1026 auto str = std::stringstream();
1027
1028 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
1031
1032 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
1038
1039 // clang-format off
1040 str << "DeviceGemmXdlUniversal"
1041 << "<"
1042 << getGemmSpecializationString(GemmSpec) << ", "
1043 << std::string(ALayout::name)[0]
1044 << std::string(BLayout::name)[0]
1045 << std::string(CLayout::name)[0]
1046 << ">"
1047 << " BlkSize: "
1048 << BlockSize << ", "
1049 << "BlkTile: "
1050 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
1051 << "WaveTile: "
1052 << MPerXDL<<"x"<<NPerXDL << ", "
1053 << "WaveMap: "
1054 << MXdlPerWave<<"x" << NXdlPerWave<<", "
1055 << "VmemReadVec: "
1056 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
1057 << "BlkGemmPipelineScheduler: "
1058 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
1059 << "BlkGemmPipelineVersion: "
1060 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
1061 << "BlkGemmPipelinePrefetchStages: "
1062 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages;
1063 // clang-format on
1064
1065 return str.str();
1066 }
1067};
1068
1069} // namespace device
1070} // namespace tensor_operation
1071} // namespace ck
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ One
Definition blkgemmpipe_scheduler.hpp:37
@ Seven
Definition blkgemmpipe_scheduler.hpp:43
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
@ Four
Definition blkgemmpipe_scheduler.hpp:40
@ Two
Definition blkgemmpipe_scheduler.hpp:38
@ Full
Definition blkgemmpipe_scheduler.hpp:49
@ Three
Definition blkgemmpipe_scheduler.hpp:39
@ Five
Definition blkgemmpipe_scheduler.hpp:41
@ Six
Definition blkgemmpipe_scheduler.hpp:42
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_batched_gemm_b_scale_xdl_cshuffle_v3(BatchedGemmArg karg)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:38
int64_t long_index_t
Definition ck.hpp:300
__global__ void kernel_batched_gemm_b_scale_xdl_cshuffle_v3_2lds(BatchedGemmArg karg)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:79
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:717
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_xdl_cshuffle_v3.hpp:247
Definition data_type.hpp:187
Definition device_base.hpp:197
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:256
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:272
__host__ __device__ constexpr long_index_t GetSacleBPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:281
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:267
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:277
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:295
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:297
ArgumentBase(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t StrideScaleB_, index_t BatchStrideA_, index_t BatchStrideB_, index_t BatchStrideC_, index_t BatchStrideScaleB_, const BScaleDataType *p_b_scale_grid_, index_t Batch_, index_t KBatch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:299
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:344
float RunImp(const ArgumentBase< GridwiseGemm > &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:346
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:843
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:865
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:179
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:1018
static constexpr index_t APackedSize
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:240
static auto MakeInvoker()
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:971
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:181
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:919
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const void *p_b_scale, index_t Batch, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:974
static constexpr index_t BPackedSize
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:247
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t StrideScaleB, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t BatchStrideScaleB, const BScaleDataType *p_b_scale, index_t Batch, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:928
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:238
static constexpr auto NXdlPerWave32
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:182
static constexpr bool IsValidCompilationParameter()
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:872
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, ScaleBlockN, ScaleBlockK, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:186
index_t GetKPerBlock() override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:924
std::string GetTypeString() const override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:1024
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:237
bool GetPermuteB() override
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:926
ArgumentBase< GridwiseGemm64 > Argument
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:340
static bool IsSupportedArgument(const Argument &arg)
Definition device_batched_gemm_xdl_fpAintB_b_scale.hpp:878
Definition device_batched_gemm.hpp:60
Definition flush_cache.hpp:299