gridwise_gemm_xdlops_v2r3.hpp Source File

gridwise_gemm_xdlops_v2r3.hpp Source File#

Composable Kernel: gridwise_gemm_xdlops_v2r3.hpp Source File
gridwise_gemm_xdlops_v2r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
17
18namespace ck {
19
20template <typename GridwiseGemm,
21 typename FloatAB,
22 typename FloatC,
23 typename AGridDesc_K0_M_K1,
24 typename BGridDesc_K0_N_K1,
25 typename CGridDesc_M_N,
26 bool HasMainKBlockLoop>
27__global__ void
28#if CK_USE_LAUNCH_BOUNDS
30#endif
31#if CK_USE_WAVES_PER_EU
32 __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
33#endif
34 kernel_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
35 const FloatAB* __restrict__ p_b_grid,
36 FloatC* __restrict__ p_c_grid,
37 const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
38 const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
39 const CGridDesc_M_N c_grid_desc_m_n)
40{
41#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
42 defined(__gfx12__)
43 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
44 {
45 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
46
47 GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
48 p_b_grid,
49 p_c_grid,
50 p_shared,
51 a_grid_desc_k0_m_k1,
52 b_grid_desc_k0_n_k1,
53 c_grid_desc_m_n);
54 }
55#else
56 ignore = p_a_grid;
57 ignore = p_b_grid;
58 ignore = p_c_grid;
59 ignore = a_grid_desc_k0_m_k1;
60 ignore = b_grid_desc_k0_n_k1;
61 ignore = c_grid_desc_m_n;
62#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
63}
64
65template <typename GridwiseGemm, bool HasMainKBlockLoop>
66__global__ void
67#if CK_USE_LAUNCH_BOUNDS
69#endif
70#if CK_USE_WAVES_PER_EU
71 __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
72#endif
73 kernel_gemm_xdlops_v2r3(const typename GridwiseGemm::Argument karg)
74{
75#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx11__) || \
76 defined(__gfx12__)
77 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
78 {
79 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
80
81 const auto a_grid_desc_k0_m_k1 =
82 amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
83 karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
84 const auto b_grid_desc_k0_n_k1 =
85 amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
86 karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
87 const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
88 karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
89
90 GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
91 karg.p_b_grid,
92 karg.p_c_grid,
93 p_shared,
94 a_grid_desc_k0_m_k1,
95 b_grid_desc_k0_n_k1,
96 c_grid_desc_m_n);
97 }
98#else
99 ignore = karg;
100#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
101}
102
103template <index_t BlockSize,
104 typename FloatAB,
105 typename FloatAcc,
106 typename FloatC,
107 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
108 typename AElementwiseOperation,
109 typename BElementwiseOperation,
110 typename CElementwiseOperation,
111 index_t MPerBlock,
112 index_t NPerBlock,
113 index_t K0PerBlock,
114 index_t MPerXdl,
115 index_t NPerXdl,
116 index_t K1Value,
117 index_t MXdlPerWave,
118 index_t NXdlPerWave,
119 typename ABlockTransferThreadClusterLengths_K0_M_K1,
120 typename ABlockTransferThreadClusterArrangeOrder,
121 typename ABlockTransferSrcAccessOrder,
122 index_t ABlockTransferSrcVectorDim,
123 index_t ABlockTransferSrcScalarPerVector,
124 index_t ABlockTransferDstScalarPerVector_K1,
125 bool AThreadTransferSrcResetCoordinateAfterRun,
126 bool ABlockLdsExtraM,
127 typename BBlockTransferThreadClusterLengths_K0_N_K1,
128 typename BBlockTransferThreadClusterArrangeOrder,
129 typename BBlockTransferSrcAccessOrder,
130 index_t BBlockTransferSrcVectorDim,
131 index_t BBlockTransferSrcScalarPerVector,
132 index_t BBlockTransferDstScalarPerVector_K1,
133 bool BThreadTransferSrcResetCoordinateAfterRun,
134 bool BBlockLdsExtraN,
135 typename CThreadTransferSrcDstAccessOrder,
136 index_t CThreadTransferSrcDstVectorDim,
137 index_t CThreadTransferDstScalarPerVector,
138 index_t NumGemmKPrefetchStage = 1,
142{
143 static constexpr auto I0 = Number<0>{};
144 static constexpr auto I1 = Number<1>{};
145 static constexpr auto I2 = Number<2>{};
146 static constexpr auto I3 = Number<3>{};
147 static constexpr auto I4 = Number<4>{};
148 static constexpr auto I5 = Number<5>{};
149 static constexpr auto I6 = Number<6>{};
150 static constexpr auto I7 = Number<7>{};
151
152 // K1 should be Number<...>
153 static constexpr bool is_single_rate_mfma =
155 (is_same<FloatAB, int8_t>::value && K1Value <= 8) ||
157 ? true
158 : false;
159 static constexpr auto is_scale_mfma = false;
160 static constexpr auto K1 = Number<math::max(
161 K1Value,
163 selected_mfma.k_per_blk)>{};
164
166
168
169 __host__ static auto CalculateGridSize(index_t M, index_t N)
170 {
171 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
172 }
173
174 template <typename CGridDesc_M_N>
175 __host__ static auto CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
176 {
177 return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1);
178 }
179
180 template <typename>
181 __host__ static auto CalculateGridSize(index_t M, index_t N)
182 {
183 return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
184 }
185
186 __host__ static auto CalculateMPadded(index_t M)
187 {
188 return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
189 }
190
191 __host__ static auto CalculateNPadded(index_t N)
192 {
193 return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
194 }
195
196 __host__ static auto CalculateK0(index_t K) { return math::integer_divide_ceil(K, K1Value); }
197
    // Problem: GEMM sizes and strides, plus the derived MPadded, NPadded and K0
199 struct Problem
200 {
201 __host__ Problem(index_t M_,
202 index_t N_,
203 index_t K_,
204 index_t StrideA_,
205 index_t StrideB_,
206 index_t StrideC_)
207 : M{M_},
208 N{N_},
209 K{K_},
210 StrideA{StrideA_},
211 StrideB{StrideB_},
212 StrideC{StrideC_},
215 K0{CalculateK0(K_)}
216 {
217 }
218
219 __host__ void Print() const
220 {
221 std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
222 << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
223 << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "K0:" << K0
224 << "}" << std::endl;
225 }
226
236 };
237
238 // Argument
240 {
241 __host__ Argument(const ElementDataTypeAB* p_a_grid_,
242 const ElementDataTypeAB* p_b_grid_,
243 FloatC* p_c_grid_,
244 index_t M_,
245 index_t N_,
246 index_t K_,
247 index_t StrideA_,
248 index_t StrideB_,
249 index_t StrideC_)
250 : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
251 p_a_grid{p_a_grid_},
252 p_b_grid{p_b_grid_},
253 p_c_grid{p_c_grid_}
254 {
255 }
256
259 FloatC* p_c_grid;
260 };
261
264
265 // denorm test fix, required to work around fp16 mfma issue
266 // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
    // when the mfma is fixed, remove this section and update
268 // FloatABAdjusted -> FloatAB throughout this file
269#if CK_GFX90A_DENORM_WORKAROUND
271#else
272 using FloatABAdjusted = FloatAB;
273#endif
274
275 __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
276 {
277 constexpr auto max_lds_align = K1;
278
279 // A matrix in LDS memory, dst of blockwise copy
280 constexpr auto a_block_desc_k0_m_k1 = [&]() {
281 if constexpr(ABlockLdsExtraM)
282 {
286 }
287 else
288 {
290 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
291 }
292 }();
293
294 return a_block_desc_k0_m_k1;
295 }
296
297 __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
298 {
299 constexpr auto max_lds_align = K1;
300
301 // B matrix in LDS memory, dst of blockwise copy
302 constexpr auto b_block_desc_k0_n_k1 = [&]() {
303 if constexpr(BBlockLdsExtraN)
304 {
308 }
309 else
310 {
312 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
313 }
314 }();
315
316 return b_block_desc_k0_n_k1;
317 }
318
319 __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
320 {
321 // LDS allocation for A and B: be careful of alignment
322 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
323
324 constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
325
326 constexpr auto max_lds_align = K1;
327
328 constexpr auto a_block_space_size_aligned =
329 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
330
331 constexpr auto b_block_space_size_aligned =
332 math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
333
334 return (a_block_space_size_aligned + b_block_space_size_aligned) *
335 sizeof(ElementDataTypeAB);
336 }
337
338 template <
339 InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
340 __device__ static bool constexpr IsValidCompilationParameter()
341 {
342 return ck::tensor_operation::device::IsValidGemmCompilationParameter<
343 BlockSize,
344 MPerBlock,
345 NPerBlock,
346 MPerXdl,
347 NPerXdl,
348 MXdlPerWave,
349 NXdlPerWave,
350 FloatC,
351 CGlobalMemoryDataOperation>();
352 }
353
354 template <typename AGridDesc_K0_M_K1, typename BGridDesc_K0_N_K1, typename CGridDesc_M_N>
355 __host__ __device__ static constexpr bool
356 CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
357 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
358 const CGridDesc_M_N& c_grid_desc_m_n)
359 {
360 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
361 "wrong! K1 need to be known at compile-time");
362
363 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
364 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
365 "Invalid tuning param!");
366
367 const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
368 const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
369 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
370
371 if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
372 K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
373 K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
374 return false;
375
376 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
377 return false;
378
379 // check gridwise gemm pipeline
380 const auto num_k_loop = K0 / K0PerBlock;
381
382 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
383 {
384 return false;
385 }
386
387 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
388 return true;
389 }
390
391 __host__ static constexpr bool CheckValidity(const Problem& problem)
392 {
393 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
394 "wrong! K1 need to be known at compile-time");
395
396 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
397 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
398 "Invalid tuning param!");
399
400 // check gridwise gemm pipeline
401 const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
402 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
403 {
404 return false;
405 }
406
407 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
408 return true;
409 }
410
411 __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
412 {
413 const index_t num_loop = math::integer_divide_ceil(K, K0PerBlock * K1);
414
415 return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
416 }
417
418 template <typename CGridDesc>
419 __host__ __device__ static constexpr auto
420 MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc& c_grid_desc_m_n)
421 {
422 constexpr auto max_lds_align = K1;
423
424 // A matrix in LDS memory, dst of blockwise copy
425 constexpr auto a_block_desc_k0_m_k1 = [&]() {
426 if constexpr(ABlockLdsExtraM)
427 {
431 }
432 else
433 {
435 make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
436 }
437 }();
438
439 // B matrix in LDS memory, dst of blockwise copy
440 constexpr auto b_block_desc_k0_n_k1 = [&]() {
441 if constexpr(BBlockLdsExtraN)
442 {
446 }
447 else
448 {
450 make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
451 }
452 }();
453
454 using BlockwiseGemm =
458 FloatAcc,
459 decltype(a_block_desc_k0_m_k1),
460 decltype(b_block_desc_k0_n_k1),
461 MPerXdl,
462 NPerXdl,
463 MXdlPerWave,
464 NXdlPerWave,
465 K1,
468
469 return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
470 }
471
472 // return block_id to C matrix tile idx (m0, n0) mapping
474
475 template <bool HasMainKBlockLoop,
476 typename AGridDesc_K0_M_K1,
477 typename BGridDesc_K0_N_K1,
478 typename CGridDesc_M_N>
479 __device__ static void Run(const ElementDataTypeAB* p_a_grid,
480 const ElementDataTypeAB* p_b_grid,
481 FloatC* p_c_grid,
482 void* __restrict__ p_shared,
483 const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
484 const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
485 const CGridDesc_M_N& c_grid_desc_m_n)
486 {
487 const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
489
490 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
491 p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
492 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
493 p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
495 p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
496
497 const AElementwiseOperation a_element_op{};
498 const BElementwiseOperation b_element_op{};
499 const CElementwiseOperation c_element_op{};
500
501 const auto block_2_ctile_map =
502 Block2CTileMap{c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)};
503
504 // divide block work by [M, N]
505 const auto block_work_idx =
506 block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
507
508 if(!block_2_ctile_map.ValidCTileIndex(
509 block_work_idx,
510 make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0),
511 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1))))
512 {
513 return;
514 }
515
516 // HACK: this force m/n_block_data_idx_on_grid into SGPR
517 const index_t m_block_data_idx_on_grid =
518 __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
519
520 const index_t n_block_data_idx_on_grid =
521 __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
522
523 // lds max alignment
524 constexpr auto max_lds_align = K1;
525
526 // A matrix in LDS memory, dst of blockwise copy
527 constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
528
529 // B matrix in LDS memory, dst of blockwise copy
530 constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
531
532 // A matrix blockwise copy
533 auto a_blockwise_copy =
535 AElementwiseOperation,
539 ABlockTransferThreadClusterLengths_K0_M_K1,
540 ABlockTransferThreadClusterArrangeOrder,
543 decltype(a_grid_desc_k0_m_k1),
544 decltype(a_block_desc_k0_m_k1),
545 ABlockTransferSrcAccessOrder,
547 ABlockTransferSrcVectorDim,
548 2,
549 ABlockTransferSrcScalarPerVector,
550 ABlockTransferDstScalarPerVector_K1,
551 1,
552 1,
553 AThreadTransferSrcResetCoordinateAfterRun,
554 true,
555 NumGemmKPrefetchStage>(
556 a_grid_desc_k0_m_k1,
557 make_multi_index(0, m_block_data_idx_on_grid, 0),
558 a_element_op,
559 a_block_desc_k0_m_k1,
560 make_multi_index(0, 0, 0),
562
563 // B matrix blockwise copy
564 auto b_blockwise_copy =
566 BElementwiseOperation,
570 BBlockTransferThreadClusterLengths_K0_N_K1,
571 BBlockTransferThreadClusterArrangeOrder,
574 decltype(b_grid_desc_k0_n_k1),
575 decltype(b_block_desc_k0_n_k1),
576 BBlockTransferSrcAccessOrder,
578 BBlockTransferSrcVectorDim,
579 2,
580 BBlockTransferSrcScalarPerVector,
581 BBlockTransferDstScalarPerVector_K1,
582 1,
583 1,
584 BThreadTransferSrcResetCoordinateAfterRun,
585 true,
586 NumGemmKPrefetchStage>(
587 b_grid_desc_k0_n_k1,
588 make_multi_index(0, n_block_data_idx_on_grid, 0),
589 b_element_op,
590 b_block_desc_k0_n_k1,
591 make_multi_index(0, 0, 0),
593
594 // GEMM definition
595 // c_mtx += transpose(a_mtx) * b_mtx
596 // a_mtx[K0PerBlock, MPerBlock] is in LDS
597 // b_mtx[K0PerBlock, NPerBlock] is in LDS
598 // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
599 // register
600 // sanity check
602 BlockSize,
605 FloatAcc,
606 decltype(a_block_desc_k0_m_k1),
607 decltype(b_block_desc_k0_n_k1),
608 MPerXdl,
609 NPerXdl,
610 MXdlPerWave,
611 NXdlPerWave,
612 K1,
613 LoopSched,
616
617 auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
618
619 // LDS allocation for A and B: be careful of alignment
620 constexpr auto a_block_space_size_aligned =
621 math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
622
624 static_cast<ElementDataTypeAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
625
627 static_cast<ElementDataTypeAB*>(p_shared) + a_block_space_size_aligned,
628 b_block_desc_k0_n_k1.GetElementSpaceSize());
629
630 constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
631 constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
632
633 // gridwise GEMM pipeline
634 const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
635 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
636
637 GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
638 a_block_desc_k0_m_k1,
639 a_blockwise_copy,
640 a_grid_buf,
641 a_block_buf,
642 a_block_slice_copy_step,
643 b_grid_desc_k0_n_k1,
644 b_block_desc_k0_n_k1,
645 b_blockwise_copy,
646 b_grid_buf,
647 b_block_buf,
648 b_block_slice_copy_step,
649 blockwise_gemm,
650 c_thread_buf,
651 num_k_block_main_loop);
652
653 // output: register to global memory
654 {
655 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
656 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
657
658 constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
659 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
660
661 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I0);
662 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I1);
663 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I2);
664 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I3);
665 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I4);
666 constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I5);
667 constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I6);
668 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(I7);
669
670 // calculate origin of thread output tensor on global memory
671 // blockwise GEMM c matrix starting index
672 const auto c_thread_mtx_on_block =
673 blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
674
675 const index_t m_thread_data_on_grid =
676 m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];
677
678 const index_t n_thread_data_on_grid =
679 n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];
680
681 const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
683 make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
686
687 const auto m_thread_data_on_grid_idx =
688 m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
689 make_multi_index(m_thread_data_on_grid));
690
691 const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor(
695
696 const auto n_thread_data_on_grid_idx =
697 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
698 make_multi_index(n_thread_data_on_grid));
699
700 auto c_thread_copy =
702 FloatC,
703 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
704 decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
705 CElementwiseOperation,
707 CThreadTransferSrcDstAccessOrder,
708 CThreadTransferSrcDstVectorDim,
709 CThreadTransferDstScalarPerVector,
710 CGlobalMemoryDataOperation,
711 1,
712 true>{
713 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
714 make_multi_index(m_thread_data_on_grid_idx[I0],
715 n_thread_data_on_grid_idx[I0],
716 m_thread_data_on_grid_idx[I1],
717 n_thread_data_on_grid_idx[I1],
718 m_thread_data_on_grid_idx[I2],
719 m_thread_data_on_grid_idx[I3],
720 m_thread_data_on_grid_idx[I4],
721 n_thread_data_on_grid_idx[I2]),
722 c_element_op};
723
724 c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
725 make_tuple(I0, I0, I0, I0, I0, I0, I0, I0),
726 c_thread_buf,
727 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
728 c_grid_buf);
729 }
730 }
731};
732
733template <index_t BlockSize,
734 typename FloatAB,
735 typename FloatAcc,
736 typename FloatC,
737 InMemoryDataOperationEnum CGlobalMemoryDataOperation,
738 typename ALayout,
739 typename BLayout,
740 typename CLayout,
741 typename AElementwiseOperation,
742 typename BElementwiseOperation,
743 typename CElementwiseOperation,
745 index_t MPerBlock,
746 index_t NPerBlock,
747 index_t K0PerBlock,
748 index_t MPerXdl,
749 index_t NPerXdl,
750 index_t K1Value,
751 index_t MXdlPerWave,
752 index_t NXdlPerWave,
753 typename ABlockTransferThreadClusterLengths_K0_M_K1,
754 typename ABlockTransferThreadClusterArrangeOrder,
755 typename ABlockTransferSrcAccessOrder,
756 index_t ABlockTransferSrcVectorDim,
757 index_t ABlockTransferSrcScalarPerVector,
758 index_t ABlockTransferDstScalarPerVector_K1,
759 bool AThreadTransferSrcResetCoordinateAfterRun,
760 bool ABlockLdsExtraM,
761 typename BBlockTransferThreadClusterLengths_K0_N_K1,
762 typename BBlockTransferThreadClusterArrangeOrder,
763 typename BBlockTransferSrcAccessOrder,
764 index_t BBlockTransferSrcVectorDim,
765 index_t BBlockTransferSrcScalarPerVector,
766 index_t BBlockTransferDstScalarPerVector_K1,
767 bool BThreadTransferSrcResetCoordinateAfterRun,
768 bool BBlockLdsExtraN,
769 typename CThreadTransferSrcDstAccessOrder,
770 index_t CThreadTransferSrcDstVectorDim,
771 index_t CThreadTransferDstScalarPerVector,
772 index_t NumGemmKPrefetchStage = 1,
777 FloatAB,
778 FloatAcc,
779 FloatC,
780 CGlobalMemoryDataOperation,
781 AElementwiseOperation,
782 BElementwiseOperation,
783 CElementwiseOperation,
784 MPerBlock,
785 NPerBlock,
786 K0PerBlock,
787 MPerXdl,
788 NPerXdl,
789 K1Value,
790 MXdlPerWave,
791 NXdlPerWave,
792 ABlockTransferThreadClusterLengths_K0_M_K1,
793 ABlockTransferThreadClusterArrangeOrder,
794 ABlockTransferSrcAccessOrder,
795 ABlockTransferSrcVectorDim,
796 ABlockTransferSrcScalarPerVector,
797 ABlockTransferDstScalarPerVector_K1,
798 AThreadTransferSrcResetCoordinateAfterRun,
799 ABlockLdsExtraM,
800 BBlockTransferThreadClusterLengths_K0_N_K1,
801 BBlockTransferThreadClusterArrangeOrder,
802 BBlockTransferSrcAccessOrder,
803 BBlockTransferSrcVectorDim,
804 BBlockTransferSrcScalarPerVector,
805 BBlockTransferDstScalarPerVector_K1,
806 BThreadTransferSrcResetCoordinateAfterRun,
807 BBlockLdsExtraN,
808 CThreadTransferSrcDstAccessOrder,
809 CThreadTransferSrcDstVectorDim,
810 CThreadTransferDstScalarPerVector,
811 NumGemmKPrefetchStage,
812 LoopSched,
813 PipelineVer>
814{
815 using Parent =
817 FloatAB,
818 FloatAcc,
819 FloatC,
820 CGlobalMemoryDataOperation,
821 AElementwiseOperation,
822 BElementwiseOperation,
823 CElementwiseOperation,
824 MPerBlock,
825 NPerBlock,
826 K0PerBlock,
827 MPerXdl,
828 NPerXdl,
829 K1Value,
830 MXdlPerWave,
831 NXdlPerWave,
832 ABlockTransferThreadClusterLengths_K0_M_K1,
833 ABlockTransferThreadClusterArrangeOrder,
834 ABlockTransferSrcAccessOrder,
835 ABlockTransferSrcVectorDim,
836 ABlockTransferSrcScalarPerVector,
837 ABlockTransferDstScalarPerVector_K1,
838 AThreadTransferSrcResetCoordinateAfterRun,
839 ABlockLdsExtraM,
840 BBlockTransferThreadClusterLengths_K0_N_K1,
841 BBlockTransferThreadClusterArrangeOrder,
842 BBlockTransferSrcAccessOrder,
843 BBlockTransferSrcVectorDim,
844 BBlockTransferSrcScalarPerVector,
845 BBlockTransferDstScalarPerVector_K1,
846 BThreadTransferSrcResetCoordinateAfterRun,
847 BBlockLdsExtraN,
848 CThreadTransferSrcDstAccessOrder,
849 CThreadTransferSrcDstVectorDim,
850 CThreadTransferDstScalarPerVector,
851 NumGemmKPrefetchStage,
852 LoopSched,
853 PipelineVer>;
854
855 using typename Parent::GridwiseGemmPipe;
856 using typename Parent::Problem;
857
858 using Parent::I1;
859
860 using Parent::K1;
861
862 __device__ static auto
864 {
865 const auto a_grid_desc_m_k = [&]() {
867 {
868 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
869 }
871 {
872 return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
873 }
874 }();
875
877 {
878 const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
879 const auto KPad = K0Pad * K1Value;
880
881 const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
882 a_grid_desc_m_k,
886
888 a_grid_desc_m_kpad,
890 make_right_pad_transform(M, MPad - M)),
893 }
895 {
897 a_grid_desc_m_k,
899 make_right_pad_transform(M, MPad - M)),
902 }
903 else
904 {
906 a_grid_desc_m_k,
911 }
912 }
913
914 __device__ static auto
916 {
917 const auto b_grid_desc_k_n = [&]() {
919 {
920 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
921 }
923 {
924 return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
925 }
926 }();
927
929 {
930 const auto K0Pad = math::integer_divide_ceil(K0, K0PerBlock) * K0PerBlock;
931 const auto KPad = K0Pad * K1Value;
932
933 const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
934 b_grid_desc_k_n,
938
940 b_grid_desc_kpad_n,
942 make_right_pad_transform(N, NPad - N)),
945 }
946
948 {
950 b_grid_desc_k_n,
952 make_right_pad_transform(N, NPad - N)),
955 }
956 else
957 {
959 b_grid_desc_k_n,
964 }
965 }
966
967 __device__ static auto
969 {
970 const auto c_grid_desc_m_n = [&]() {
972 {
973 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
974 }
976 {
977 return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
978 }
979 }();
980
983 {
984 return transform_tensor_descriptor(c_grid_desc_m_n,
986 make_right_pad_transform(N, NPad - N)),
989 }
990 else
991 {
992
994 c_grid_desc_m_n,
998 }
999 }
1000
1002
1003 __host__ static constexpr bool CheckValidity(const Problem& problem)
1004 {
1005 static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
1006 "wrong! K1 need to be known at compile-time");
1007
1008 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1009 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1010 "Invalid tuning param!");
1011
1012 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1016 {
1017 if(!(problem.M % MPerBlock == 0))
1018 {
1019 return false;
1020 }
1021 }
1022
1023 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1027 {
1028 if(!(problem.N % NPerBlock == 0))
1029 {
1030 return false;
1031 }
1032 }
1033
1034 if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1038 {
1039 if(!(problem.K0 % K0PerBlock == 0))
1040 {
1041 return false;
1042 }
1043 }
1044
1046 {
1047 if(problem.K % ABlockTransferSrcScalarPerVector != 0)
1048 {
1049 return false;
1050 }
1051 }
1052 else
1053 {
1054 if(problem.M % ABlockTransferSrcScalarPerVector != 0)
1055 {
1056 return false;
1057 }
1058 }
1059
1061 {
1062 if(problem.N % BBlockTransferSrcScalarPerVector != 0)
1063 {
1064 return false;
1065 }
1066 }
1067 else
1068 {
1069 if(problem.K % BBlockTransferSrcScalarPerVector != 0)
1070 {
1071 return false;
1072 }
1073 }
1074
1075 // check gridwise gemm pipeline
1076 const auto num_k_loop = math::integer_divide_ceil(problem.K0, K0PerBlock);
1077
1078 if(!GridwiseGemmPipe::IsSupported(num_k_loop))
1079 {
1080 return false;
1081 }
1082
1083 // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1084 return true;
1085 }
1086};
1087
1088} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)
Definition device_base.hpp:178
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ NKPadding
Definition gemm_specialization.hpp:19
Definition ck.hpp:268
ushort bhalf_t
Definition data_type.hpp:30
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition blockwise_gemm_xdlops.hpp:620
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
constexpr auto GridwiseGemmPipeline_Selector()
Definition gridwise_gemm_pipeline_selector.hpp:31
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition tensor_description/tensor_adaptor.hpp:425
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition tensor_descriptor_helper.hpp:132
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
Definition block_to_ctile_map.hpp:261
Definition blockwise_gemm_smfmac_xdlops.hpp:44
__host__ Argument(const ElementDataTypeAB *p_a_grid_, const ElementDataTypeAB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_xdlops_v2r3.hpp:241
const ElementDataTypeAB * p_b_grid
Definition gridwise_gemm_xdlops_v2r3.hpp:258
const ElementDataTypeAB * p_a_grid
Definition gridwise_gemm_xdlops_v2r3.hpp:257
FloatC * p_c_grid
Definition gridwise_gemm_xdlops_v2r3.hpp:259
Definition gridwise_gemm_xdlops_v2r3.hpp:200
index_t NPadded
Definition gridwise_gemm_xdlops_v2r3.hpp:234
index_t K
Definition gridwise_gemm_xdlops_v2r3.hpp:229
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition gridwise_gemm_xdlops_v2r3.hpp:201
index_t StrideB
Definition gridwise_gemm_xdlops_v2r3.hpp:231
index_t N
Definition gridwise_gemm_xdlops_v2r3.hpp:228
index_t StrideC
Definition gridwise_gemm_xdlops_v2r3.hpp:232
__host__ void Print() const
Definition gridwise_gemm_xdlops_v2r3.hpp:219
index_t StrideA
Definition gridwise_gemm_xdlops_v2r3.hpp:230
index_t MPadded
Definition gridwise_gemm_xdlops_v2r3.hpp:233
index_t M
Definition gridwise_gemm_xdlops_v2r3.hpp:227
index_t K0
Definition gridwise_gemm_xdlops_v2r3.hpp:235
Definition gridwise_gemm_xdlops_v2r3.hpp:814
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ADataType, AccDataType, CDataType, CGlobalMemoryDataOperation, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1Value, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, NumGemmKPrefetchStage, LoopSched, PipelineVer > Parent
Definition gridwise_gemm_xdlops_v2r3.hpp:815
Definition gridwise_gemm_xdlops_v2r3.hpp:142
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition xdlops_gemm.hpp:1208
Definition utility/sequence.hpp:43
Blockwise data transfer.
Definition thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition threadwise_tensor_slice_transfer.hpp:39
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition is_known_at_compile_time.hpp:14
Definition device_base.hpp:197
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340