device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp Source File

device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp Source File

Composable Kernel: device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp Source File
device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
20
21namespace ck {
22
// Device kernel for the first stage of the fused op ("GEMM + Welford 1st part").
//
// NOTE(review): this listing has dropped original lines 28 and 33. Line 33 carries the
// kernel name; the cross-reference index at the end of this page gives it as
// kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3. Line 28 presumably declares
// the `TailNum` template parameter that the Run call below uses -- confirm against the
// full file.
//
// Parameters:
//   karg                  - GridwiseGemm::Argument, passed by value.
//   p_welford_mean_grid   - handed to the EpilogueWelfordCShuffle constructor.
//   p_welford_var_grid    - handed to the EpilogueWelfordCShuffle constructor.
//   p_welford_count_grid  - handed to the EpilogueWelfordCShuffle constructor.
//
// The body is only compiled when __gfx11__ or __gfx12__ is defined; in every other
// compilation the arguments are assigned to `ignore` and the kernel does nothing.
23template <typename GridwiseGemm,
24 typename EMeanVarDataType,
25 bool HasMainKBlockLoop,
26 InMemoryDataOperationEnum EGlobalMemoryDataOperation,
27 index_t MinimumOccupancy = 1,
29__global__ void
30#if CK_USE_LAUNCH_BOUNDS
31__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
32#endif
34 typename GridwiseGemm::Argument karg,
35 EMeanVarDataType* __restrict__ p_welford_mean_grid,
36 EMeanVarDataType* __restrict__ p_welford_var_grid,
37 int32_t* __restrict__ p_welford_count_grid)
38{
39#if(defined(__gfx11__) || defined(__gfx12__))
40#if defined(__gfx11__)
41 // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
// On gfx11 the whole body is therefore discarded at compile time when the E element type
// is half_t or bhalf_t and the memory operation is AtomicAdd.
42 using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
43 if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
44 (std::is_same_v<e_data_type, ck::half_t> ||
45 std::is_same_v<e_data_type, ck::bhalf_t>)))
46 {
47#endif
// Shared-memory buffer whose byte size is queried from GridwiseGemm for the Welford
// epilogue type.
48 constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
49 typename GridwiseGemm::EpilogueWelfordCShuffle>();
50
51 __shared__ char p_shared[LDS_size];
52
// The split-K batch offset is built from blockIdx.z.
53 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
54
// Epilogue object bundling the three Welford output pointers with karg.M and karg.N.
55 auto epilogue_args = typename GridwiseGemm::EpilogueWelfordCShuffle(
56 p_welford_mean_grid, p_welford_var_grid, p_welford_count_grid, karg.M, karg.N);
57
58 GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
59 p_shared, splitk_batch_offset, karg, epilogue_args);
60
61#if defined(__gfx11__)
62 }
63#endif
64#else
// Neither gfx11 nor gfx12: reference the parameters so they do not count as unused.
65 ignore = karg;
66 ignore = p_welford_mean_grid;
67 ignore = p_welford_var_grid;
68 ignore = p_welford_count_grid;
69#endif
70}
71
// Device kernel for the second stage of the fused op ("Welford 2nd part").
//
// NOTE(review): this listing has dropped original lines 85 and 87. Line 87 carries the
// kernel name; the cross-reference index at the end of this page gives it as
// kernel_welford_layernorm2d_second_half. Line 85 sits between `#if CK_USE_LAUNCH_BOUNDS`
// and `#endif` and is presumably a __launch_bounds__ attribute -- confirm against the
// full file.
//
// The kernel contains no logic of its own: every parameter is forwarded, unchanged and in
// declaration order, to GridwiseWelfordLayernorm::Run.
72template <typename GridwiseWelfordLayernorm,
73 typename EMeanVarDataType,
74 typename HDataType,
75 typename GammaDataType,
76 typename BetaDataType,
77 typename ComputeDataType,
78 typename EHGridDesc_M_N,
79 typename LayernormMeanVarGridDesc_M_NBlock,
80 typename LayernormCountGridDesc_M_NBlock,
81 typename GammaBetaGridDesc_N,
82 typename HElementwiseOperation>
83__global__ void
84#if CK_USE_LAUNCH_BOUNDS
86#endif
88 const EMeanVarDataType* __restrict__ p_e_grid,
89 const EMeanVarDataType* __restrict__ p_in_welford_mean_grid,
90 const EMeanVarDataType* __restrict__ p_in_welford_var_grid,
91 const int32_t* __restrict__ p_in_welford_count_grid,
92 const GammaDataType* __restrict__ p_gamma_grid,
93 const BetaDataType* __restrict__ p_beta_grid,
94 HDataType* __restrict__ p_h_grid,
95 const EHGridDesc_M_N e_grid_desc_m_n,
96 const EHGridDesc_M_N h_grid_desc_m_n,
97 const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock,
98 const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock,
99 const GammaBetaGridDesc_N gamma_grid_desc_n,
100 const GammaBetaGridDesc_N beta_grid_desc_n,
101 index_t numMeanVarCountBlockTileIteration_N,
102 index_t NBlockClusterLength,
103 ComputeDataType epsilon,
104 HElementwiseOperation h_element_op)
105{
// Straight pass-through to the grid-wise implementation.
106 GridwiseWelfordLayernorm::Run(p_e_grid,
107 p_in_welford_mean_grid,
108 p_in_welford_var_grid,
109 p_in_welford_count_grid,
110 p_gamma_grid,
111 p_beta_grid,
112 p_h_grid,
113 e_grid_desc_m_n,
114 h_grid_desc_m_n,
115 mean_var_grid_desc_m_nblock,
116 count_grid_desc_m_nblock,
117 gamma_grid_desc_n,
118 beta_grid_desc_n,
119 numMeanVarCountBlockTileIteration_N,
120 NBlockClusterLength,
121 epsilon,
122 h_element_op);
123}
124
125} // namespace ck
126
127namespace ck {
128namespace tensor_operation {
129namespace device {
130
131template <typename ALayout,
132 typename BLayout,
133 typename DsLayout,
134 typename HLayout,
135 typename ADataType,
136 typename BDataType,
137 typename DsDataType,
138 typename HDataType,
139 typename AccDataType,
140 typename CShuffleDataType,
141 typename EMeanVarDataType, // LayerNorm
142 typename GammaDataType, // LayerNorm
143 typename BetaDataType, // LayerNorm
144 typename AElementwiseOperation,
145 typename BElementwiseOperation,
146 typename CDEElementwiseOperation,
147 typename HElementwiseOperation,
148 GemmSpecialization GemmSpec,
149 index_t BlockSize,
150 index_t MPerBlock,
151 index_t NPerBlock,
152 index_t KPerBlock,
153 index_t AK1,
154 index_t BK1,
155 index_t MPerWmma,
156 index_t NPerWmma,
157 index_t MRepeat,
158 index_t NRepeat,
159 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
160 typename ABlockTransferThreadClusterArrangeOrder,
161 typename ABlockTransferSrcAccessOrder,
162 index_t ABlockTransferSrcVectorDim,
163 index_t ABlockTransferSrcScalarPerVector,
164 index_t ABlockTransferDstScalarPerVector_AK1,
165 bool ABlockLdsExtraM,
166 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
167 typename BBlockTransferThreadClusterArrangeOrder,
168 typename BBlockTransferSrcAccessOrder,
169 index_t BBlockTransferSrcVectorDim,
170 index_t BBlockTransferSrcScalarPerVector,
171 index_t BBlockTransferDstScalarPerVector_BK1,
172 bool BBlockLdsExtraN,
173 index_t CShuffleMRepeatPerShuffle,
174 index_t CShuffleNRepeatPerShuffle,
175 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
176 index_t CDEShuffleBlockTransferScalarPerVector,
177 typename LayernormThreadClusterSize_M_N,
178 index_t LayernormThreadSliceSize_M,
181 typename ComputeTypeA = HDataType,
182 typename ComputeTypeB = ComputeTypeA,
183 bool PermuteA = false,
184 bool PermuteB = false>
186 : public DeviceGemmMultipleDLayernorm<ALayout,
187 BLayout,
188 DsLayout,
189 HLayout,
190 ADataType,
191 BDataType,
192 DsDataType,
193 GammaDataType,
194 BetaDataType,
195 HDataType,
196 AElementwiseOperation,
197 BElementwiseOperation,
198 CDEElementwiseOperation,
199 HElementwiseOperation>
200{
201 // EDataType, MeanDataType and VarDataType must be the same.
203
204 static constexpr index_t NumDTensor = DsDataType::Size();
205 static constexpr index_t LayernormHDstVectorSize = CDEShuffleBlockTransferScalarPerVector;
206 static constexpr index_t LayernormGammaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
207 static constexpr index_t LayernormBetaSrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
208 static constexpr index_t LayernormESrcVectorSize = CDEShuffleBlockTransferScalarPerVector;
209 static constexpr index_t LayernormThreadSliceSize_N = CDEShuffleBlockTransferScalarPerVector;
210
212 Sequence<LayernormThreadClusterSize_M_N::At(0) * LayernormThreadSliceSize_M,
213 LayernormThreadClusterSize_M_N::At(1) * LayernormThreadSliceSize_N>;
214
215 static constexpr auto I0 = Number<0>{};
216 static constexpr auto I1 = Number<1>{};
217 static constexpr auto I2 = Number<2>{};
218 static constexpr auto I3 = Number<3>{};
219
221 Sequence<CDEShuffleBlockTransferScalarPerVector,
222 CDEShuffleBlockTransferScalarPerVector,
223 CDEShuffleBlockTransferScalarPerVector>;
224
225 // GEMM + Welford 1st part kernel
227 ALayout,
228 BLayout,
229 DsLayout,
230 HLayout,
233 AccDataType,
234 CShuffleDataType,
235 DsDataType,
236 EMeanVarDataType,
237 AElementwiseOperation,
238 BElementwiseOperation,
239 CDEElementwiseOperation,
240 GemmSpec,
241 BlockSize,
242 MPerBlock,
243 NPerBlock,
244 KPerBlock,
245 AK1,
246 BK1,
247 MPerWmma,
248 NPerWmma,
249 MRepeat,
250 NRepeat,
251 ABlockTransferThreadClusterLengths_AK0_M_AK1,
252 ABlockTransferThreadClusterArrangeOrder,
253 ABlockTransferSrcAccessOrder,
254 ABlockTransferSrcVectorDim,
255 ABlockTransferSrcScalarPerVector,
256 ABlockTransferDstScalarPerVector_AK1,
257 false,
258 ABlockLdsExtraM,
259 BBlockTransferThreadClusterLengths_BK0_N_BK1,
260 BBlockTransferThreadClusterArrangeOrder,
261 BBlockTransferSrcAccessOrder,
262 BBlockTransferSrcVectorDim,
263 BBlockTransferSrcScalarPerVector,
264 BBlockTransferDstScalarPerVector_BK1,
265 false,
266 BBlockLdsExtraN,
267 CShuffleMRepeatPerShuffle,
268 CShuffleNRepeatPerShuffle,
269 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
271 BlkGemmPipeSched,
272 BlkGemmPipelineVer,
273 ComputeTypeA,
274 ComputeTypeB,
275 PermuteA,
276 PermuteB>;
277
278 // Welford 2nd part kernel
279 template <typename DoPads, index_t MPerTile, index_t NPerTile>
281 {
282 // Only support row major for E and H
283 const auto grid_desc_m_n =
285 return PadTensorDescriptor(grid_desc_m_n, make_tuple(MPerTile, NPerTile), DoPads{});
286 }
287
288 template <index_t XPerTile>
290 {
291 const auto grid_desc_x = make_naive_tensor_descriptor_packed(make_tuple(X));
292 return PadTensorDescriptor(grid_desc_x, make_tuple(XPerTile), Sequence<true>{});
293 }
294
296 decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
300
302 decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
306
309
312 HDataType,
313 GammaDataType,
314 BetaDataType,
315 AccDataType,
320 HElementwiseOperation,
321 BlockSize,
322 LayernormThreadClusterSize_M_N::At(I0),
323 LayernormThreadClusterSize_M_N::At(I1),
324 LayernormThreadSliceSize_M,
330
331 // Argument
332 struct Argument : public BaseArgument
333 {
334 Argument(const void* p_a_grid,
335 const void* p_b_grid,
336 std::array<const void*, NumDTensor> p_ds_grid,
337 const void* p_gamma_grid,
338 const void* p_beta_grid,
339 void* p_h_grid,
340 index_t MRaw,
341 index_t NRaw,
342 index_t KRaw,
343 index_t StrideA,
344 index_t StrideB,
345 std::array<index_t, NumDTensor> StrideDs,
346 index_t StrideH,
347 double epsilon,
348 AElementwiseOperation a_element_op,
349 BElementwiseOperation b_element_op,
350 CDEElementwiseOperation cde_element_op,
351 HElementwiseOperation h_element_op)
352 : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
353 p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
354 p_ds_grid_{},
355 p_workspace_e_grid_{nullptr},
356 p_workspace_mean_{nullptr},
357 p_workspace_var_{nullptr},
358 p_workspace_count_{nullptr},
359 p_gamma_grid_{static_cast<const GammaDataType*>(p_gamma_grid)},
360 p_beta_grid_{static_cast<const BetaDataType*>(p_beta_grid)},
361 p_h_grid_{static_cast<HDataType*>(p_h_grid)},
366 MRaw, NRaw, StrideH)},
377 MRaw, NRaw, StrideH)},
378 a_element_op_{a_element_op},
379 b_element_op_{b_element_op},
380 cde_element_op_{cde_element_op},
381 h_element_op_{h_element_op},
382 MRaw_{MRaw},
383 NRaw_{NRaw},
384 KRaw_{KRaw},
385 StrideA_{StrideA},
386 StrideB_{StrideB},
387 StrideDs_{StrideDs},
388 StrideH_{StrideH},
389 gemm_nblock_{math::integer_divide_ceil(NRaw, NPerBlock)},
390 epsilon_{static_cast<AccDataType>(epsilon)}
391 {
392 static_for<0, NumDTensor, 1>{}([&](auto i) { p_ds_grid_[i] = p_ds_grid[i]; });
393
395 GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N<
399
401 GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N<
405 }
406
407 // pointers
408 const ADataType* p_a_grid_;
409 const BDataType* p_b_grid_;
410 std::array<const void*, NumDTensor> p_ds_grid_;
415 const GammaDataType* p_gamma_grid_;
416 const BetaDataType* p_beta_grid_;
417 HDataType* p_h_grid_;
418
419 // tensor descriptors (Welford second half)
426
427 // element-wise op
428 AElementwiseOperation a_element_op_;
429 BElementwiseOperation b_element_op_;
430 CDEElementwiseOperation cde_element_op_;
431 HElementwiseOperation h_element_op_;
432
438 std::array<index_t, NumDTensor> StrideDs_;
441 AccDataType epsilon_;
442 };
443
444 // Invoker
445 struct Invoker : public BaseInvoker
446 {
447 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
448 {
449 typename GridwiseGemmWelford::Argument gemm_arg{
450 std::array<const void*, 1>{arg.p_a_grid_},
451 std::array<const void*, 1>{arg.p_b_grid_},
452 arg.p_ds_grid_,
453 static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
454 arg.MRaw_,
455 arg.NRaw_,
456 arg.KRaw_,
457 std::array<index_t, 1>{arg.StrideA_}, // StrideAs
458 std::array<index_t, 1>{arg.StrideB_}, // StrideBs
459 arg.StrideDs_, // StrideDs
460 arg.StrideH_, // StrideE
461 I1, // kbatch
462 arg.a_element_op_,
463 arg.b_element_op_,
464 arg.cde_element_op_};
465
466 if(stream_config.log_level_ > 0)
467 {
468 gemm_arg.Print();
469 GridwiseGemmWelford::BlockwiseGemmPipe::HotLoopInstList::Print();
470 }
471
473 {
474 throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
475 }
476
477 if(arg.p_workspace_e_grid_ == nullptr || arg.p_workspace_mean_ == nullptr ||
478 arg.p_workspace_var_ == nullptr || arg.p_workspace_count_ == nullptr)
479 throw std::runtime_error("wrong! WorkSpace pointer has not been set");
480
481 index_t gdx, gdy, gdz;
482 std::tie(gdx, gdy, gdz) =
484
485 float ave_time = 0;
486
487 index_t K_split = (arg.KRaw_ + KPerBlock - 1) / KPerBlock * KPerBlock;
488
489 const bool has_main_k_block_loop =
491
492 const auto Run = [&](const auto& kernel_gemm_welford_first_half) {
493 // Note: cache flushing not supported
494
495 const auto kernel_welford_second_half =
497 EMeanVarDataType,
498 HDataType,
499 GammaDataType,
500 BetaDataType,
501 AccDataType,
506 HElementwiseOperation>;
507
508 // First kernel launch: GEMM + Welford first part
509 ave_time +=
510 launch_and_time_kernel(stream_config,
511 kernel_gemm_welford_first_half,
512 dim3(gdx, gdy, gdz),
513 dim3(BlockSize),
514 0,
515 gemm_arg,
516 static_cast<EMeanVarDataType*>(arg.p_workspace_mean_),
517 static_cast<EMeanVarDataType*>(arg.p_workspace_var_),
518 static_cast<int32_t*>(arg.p_workspace_count_));
519
520 // Second kernel launch: Welford second part
521 const auto M = arg.h_grid_desc_m_n_.GetLength(I0);
522 const auto N = arg.h_grid_desc_m_n_.GetLength(I1);
523
524 index_t MBlockClusterLength =
526 index_t NBlockClusterLength =
528
529 auto grid_size = MBlockClusterLength * NBlockClusterLength;
530
531 index_t numMeanVarCountBlockTileIteration_N = math::integer_divide_ceil(
532 arg.gemm_nblock_, LayernormThreadClusterSize_M_N::At(I1));
533
534 ave_time += launch_and_time_kernel(
535 stream_config,
536 kernel_welford_second_half,
537 dim3(grid_size),
538 dim3(BlockSize),
539 0,
540 static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
541 static_cast<const EMeanVarDataType*>(arg.p_workspace_mean_),
542 static_cast<const EMeanVarDataType*>(arg.p_workspace_var_),
543 static_cast<const int32_t*>(arg.p_workspace_count_),
544 arg.p_gamma_grid_,
545 arg.p_beta_grid_,
546 arg.p_h_grid_,
553 numMeanVarCountBlockTileIteration_N,
554 NBlockClusterLength,
555 arg.epsilon_,
556 arg.h_element_op_);
557 };
558
559 constexpr index_t minimum_occupancy = []() {
560 if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
561 {
562 return 2;
563 }
564 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
565 {
566 return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
567 }
568 else
569 {
570 return 1;
571 }
572 }();
573
574 if(has_main_k_block_loop)
575 {
576 // Tail number always full
577 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
578 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
579 {
582 EMeanVarDataType,
583 true,
585 minimum_occupancy>;
586 Run(kernel);
587 }
588 }
589 else
590 {
591 // Tail number always 1
592 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
593 {
596 EMeanVarDataType,
597 false,
599 minimum_occupancy>;
600 Run(kernel);
601 }
602 }
603
604 return ave_time;
605 }
606
607 // polymorphic
608 float Run(const BaseArgument* p_arg,
609 const StreamConfig& stream_config = StreamConfig{}) override
610 {
611 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
612 }
613 };
614
615 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
616 {
617 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
618
619 size_t workspace_size = 0;
620
621 int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
622
623 // workspace for welford intermediate mean
624 workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128;
625
626 // workspace for welford intermediate variance
627 workspace_size += gemm_welford_size * sizeof(EMeanVarDataType) + 128;
628
629 // workspace for welford intermediate count
630 workspace_size += pArg_->gemm_nblock_ * sizeof(int32_t) + 128;
631
633 workspace_size += pArg_->MRaw_ * pArg_->NRaw_ * sizeof(EMeanVarDataType);
634
635 return (workspace_size);
636 };
637
639 void* p_workspace,
640 const StreamConfig& = StreamConfig{}) const override
641 {
642 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
643
644 pArg_->p_workspace_ = p_workspace;
645
646 int gemm_welford_size = pArg_->MRaw_ * pArg_->gemm_nblock_;
647
648 // setup buffer used for intermediate welford mean
649 pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
650
651 index_t mean_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
652 mean_space_sz = math::integer_least_multiple(mean_space_sz, 128);
653
654 // setup buffer used for intermediate welford variance
655 pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
656
657 index_t variance_space_sz = gemm_welford_size * sizeof(EMeanVarDataType);
658 variance_space_sz = math::integer_least_multiple(variance_space_sz, 128);
659
660 // setup buffer used for intermediate welford count
661 pArg_->p_workspace_count_ =
662 reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
663
664 index_t count_space_sz = gemm_welford_size * sizeof(int32_t);
665 count_space_sz = math::integer_least_multiple(count_space_sz, 128);
666
668 pArg_->p_workspace_e_grid_ =
669 reinterpret_cast<char*>(pArg_->p_workspace_count_) + count_space_sz;
670 else
671 pArg_->p_workspace_e_grid_ = static_cast<void*>(pArg_->p_h_grid_);
672 };
673
674 static bool IsSupportedArgument(const Argument& arg)
675 {
677 {
678 return false;
679 }
680
681 // No need to check for splitK because we force KBatch = 1 (no support)
682
683 if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
684 std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
685 {
687 {
688 return false;
689 }
690 }
691
692 if((arg.KRaw_ % AK1 != 0 || arg.KRaw_ % BK1 != 0) &&
693 !(GemmSpec == GemmSpecialization::MKPadding ||
694 GemmSpec == GemmSpecialization::NKPadding ||
695 GemmSpec == GemmSpecialization::MNKPadding ||
696 GemmSpec == GemmSpecialization::KPadding))
697 {
698 return false;
699 }
700
701 typename GridwiseGemmWelford::Argument gemm_arg{
702 std::array<const void*, 1>{arg.p_a_grid_},
703 std::array<const void*, 1>{arg.p_b_grid_},
704 arg.p_ds_grid_,
705 static_cast<EMeanVarDataType*>(arg.p_workspace_e_grid_),
706 arg.MRaw_,
707 arg.NRaw_,
708 arg.KRaw_,
709 std::array<index_t, 1>{arg.StrideA_}, // StrideAs
710 std::array<index_t, 1>{arg.StrideB_}, // StrideBs
711 arg.StrideDs_, // StrideDs
712 arg.StrideH_, // StrideE
713 I1, // kbatch
714 arg.a_element_op_,
715 arg.b_element_op_,
716 arg.cde_element_op_};
717
718 const auto a_grid_desc_ak0_m_ak1 =
719 GridwiseGemmWelford::MakeAsGridDescriptor_AK0_M_AK1(gemm_arg.M,
720 gemm_arg.MPadded,
721 gemm_arg.K,
722 gemm_arg.KPadded,
723 gemm_arg.StrideAs,
724 gemm_arg.AK0);
725 const auto b_grid_desc_bk0_n_bk1 =
726 GridwiseGemmWelford::MakeBsGridDescriptor_BK0_N_BK1(gemm_arg.K,
727 gemm_arg.KPadded,
728 gemm_arg.N,
729 gemm_arg.NPadded,
730 gemm_arg.StrideBs,
731 gemm_arg.BK0);
732
733 const auto M = a_grid_desc_ak0_m_ak1[I0].GetLength(I1);
734 const auto N = b_grid_desc_bk0_n_bk1[I0].GetLength(I1);
735 const auto K =
736 a_grid_desc_ak0_m_ak1[I0].GetLength(I0) * a_grid_desc_ak0_m_ak1[I0].GetLength(I2);
737
738 if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
739 {
740 return false;
741 }
742
743 return GridwiseGemmWelford::CheckValidity(gemm_arg);
744 }
745
746 // polymorphic
747 bool IsSupportedArgument(const BaseArgument* p_arg) override
748 {
749 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
750 }
751
752 static auto MakeArgument(const void* p_a,
753 const void* p_b,
754 std::array<const void*, NumDTensor> p_ds,
755 const void* p_gamma,
756 const void* p_beta,
757 void* p_h,
758 index_t MRaw,
759 index_t NRaw,
760 index_t KRaw,
761 index_t StrideA,
762 index_t StrideB,
763 std::array<index_t, NumDTensor> StrideDs,
764 index_t StrideH,
765 double epsilon,
766 AElementwiseOperation a_element_op,
767 BElementwiseOperation b_element_op,
768 CDEElementwiseOperation cde_element_op,
769 HElementwiseOperation h_element_op)
770 {
771 return Argument{p_a,
772 p_b,
773 p_ds,
774 p_gamma,
775 p_beta,
776 p_h,
777 MRaw,
778 NRaw,
779 KRaw,
780 StrideA,
781 StrideB,
782 StrideDs,
783 StrideH,
784 epsilon,
785 a_element_op,
786 b_element_op,
787 cde_element_op,
788 h_element_op};
789 }
790
791 static auto MakeInvoker() { return Invoker{}; }
792
793 // polymorphic
794 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
795 const void* p_b,
796 std::array<const void*, NumDTensor> p_ds,
797 const void* p_gamma,
798 const void* p_beta,
799 void* p_h,
800 index_t MRaw,
801 index_t NRaw,
802 index_t KRaw,
803 index_t StrideA,
804 index_t StrideB,
805 std::array<index_t, NumDTensor> StrideDs,
806 index_t StrideH,
807 double epsilon,
808 AElementwiseOperation a_element_op,
809 BElementwiseOperation b_element_op,
810 CDEElementwiseOperation cde_element_op,
811 HElementwiseOperation h_element_op) override
812 {
813 return std::make_unique<Argument>(p_a,
814 p_b,
815 p_ds,
816 p_gamma,
817 p_beta,
818 p_h,
819 MRaw,
820 NRaw,
821 KRaw,
822 StrideA,
823 StrideB,
824 StrideDs,
825 StrideH,
826 epsilon,
827 a_element_op,
828 b_element_op,
829 cde_element_op,
830 h_element_op);
831 }
832
833 // polymorphic
834 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
835 {
836 return std::make_unique<Invoker>(Invoker{});
837 }
838
839 // polymorphic
840 std::string GetTypeString() const override
841 {
842 auto str = std::stringstream();
843
844 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
847
848 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
854
855 // clang-format off
856 str << "DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3"
857 << ">"
858 << "BlkSize: "
859 << BlockSize << ", "
860 << "BlkTile: "
861 << MPerBlock << ", "
862 << NPerBlock << ", "
863 << KPerBlock << ", "
864 << "WaveTile: "
865 << MPerWmma << "x"<<NPerWmma << ", "
866 << "WaveMap: "
867 << MRepeat << "x" << NRepeat << ", "
868 << "VmemReadVec: "
869 << ABlockTransferSrcScalarPerVector << "x" << BBlockTransferSrcScalarPerVector << ", "
870 << "GemmSpec: "
871 << getGemmSpecializationString(GemmSpec) << ", "
872 << "VmemWriteThreadCluster: "
873 << CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I1) << ", "
874 << CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(I3) << ", "
875 << "LayerNormThreadCluster: "
876 << LayernormThreadClusterSize_M_N::At(I0) << ", "
877 << LayernormThreadClusterSize_M_N::At(I1) << ", "
878 << "LayerNormThreadSliceSize: "
879 << LayernormThreadSliceSize_M << ", "
880 << "BlkGemmPipelineScheduler: "
881 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
882 << "BlkGemmPipelineVersion: "
883 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
884 << "BlkGemmPipelinePrefetchStages: "
885 << GridwiseGemmWelford::BlockwiseGemmPipe::PrefetchStages << ", "
886 << "KPack: "
888 // clang-format on
889
890 return str.str();
891 }
892};
893
894} // namespace device
895} // namespace tensor_operation
896} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition utility/math.hpp:13
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_welford_layernorm2d_second_half(const EMeanVarDataType *__restrict__ p_e_grid, const EMeanVarDataType *__restrict__ p_in_welford_mean_grid, const EMeanVarDataType *__restrict__ p_in_welford_var_grid, const int32_t *__restrict__ p_in_welford_count_grid, const GammaDataType *__restrict__ p_gamma_grid, const BetaDataType *__restrict__ p_beta_grid, HDataType *__restrict__ p_h_grid, const EHGridDesc_M_N e_grid_desc_m_n, const EHGridDesc_M_N h_grid_desc_m_n, const LayernormMeanVarGridDesc_M_NBlock mean_var_grid_desc_m_nblock, const LayernormCountGridDesc_M_NBlock count_grid_desc_m_nblock, const GammaBetaGridDesc_N gamma_grid_desc_n, const GammaBetaGridDesc_N beta_grid_desc_n, index_t numMeanVarCountBlockTileIteration_N, index_t NBlockClusterLength, ComputeDataType epsilon, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:87
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
__global__ void kernel_gemm_multiple_d_welford_first_half_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg, EMeanVarDataType *__restrict__ p_welford_mean_grid, EMeanVarDataType *__restrict__ p_welford_var_grid, int32_t *__restrict__ p_welford_count_grid)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:33
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
typename remove_pointer< T >::type remove_pointer_t
Definition type.hpp:300
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
static __host__ constexpr bool CheckValidity(const Argument &karg)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:624
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:273
static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
static __host__ constexpr bool CalculateHasMainKBlockLoop(index_t K)
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:837
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition gridwise_welford_second_half_layernorm2d.hpp:42
Definition utility/sequence.hpp:43
Definition utility/tuple.hpp:117
Definition functional2.hpp:33
Definition device_base.hpp:197
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:333
const GammaDataType * p_gamma_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:415
EHGridDesc_M_N layernorm_e_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:420
std::array< index_t, NumDTensor > StrideDs_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:438
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:409
LayernormMeanVarGridDesc_M_NBlock layernorm_mean_var_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:421
void * p_workspace_mean_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:412
index_t KRaw_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:435
index_t gemm_nblock_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:440
index_t MRaw_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:433
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:429
EHGridDesc_M_N h_grid_desc_m_n_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:425
void * p_workspace_count_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:414
index_t StrideA_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:436
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:428
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:408
GammaBetaGridDesc_N gamma_grid_desc_n_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:423
void * p_workspace_var_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:413
LayernormCountGridDesc_M_NBlock layernorm_count_grid_desc_m_nblock_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:422
index_t StrideB_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:437
HDataType * p_h_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:417
index_t StrideH_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:439
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, const void *p_gamma_grid, const void *p_beta_grid, void *p_h_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:334
void * p_workspace_e_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:411
HElementwiseOperation h_element_op_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:431
index_t NRaw_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:434
AccDataType epsilon_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:441
const BetaDataType * p_beta_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:416
std::array< const void *, NumDTensor > p_ds_grid_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:410
GammaBetaGridDesc_N beta_grid_desc_n_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:424
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:430
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:446
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:608
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:447
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:200
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:674
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeMeanVarDescriptor_M_N< Sequence< true, true >, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)) LayernormMeanVarGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:295
static constexpr index_t LayernormGammaSrcVectorSize
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:206
decltype(MakeEHGridDescriptor_M_N< Sequence< true, true >, 1, 1 >(1, 1, 1)) EHGridDesc_M_N
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:308
static constexpr index_t LayernormHDstVectorSize
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:205
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:204
Sequence< CDEShuffleBlockTransferScalarPerVector, CDEShuffleBlockTransferScalarPerVector, CDEShuffleBlockTransferScalarPerVector > CDEShuffleBlockTransferScalarPerVectors
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:220
static constexpr auto I2
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:217
static auto MakeEHGridDescriptor_M_N(index_t M, index_t N, index_t Stride)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:280
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:834
std::string GetTypeString() const override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:840
static constexpr index_t LayernormESrcVectorSize
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:208
static constexpr index_t LayernormBetaSrcVectorSize
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:207
Sequence< LayernormThreadClusterSize_M_N::At(0) *LayernormThreadSliceSize_M, LayernormThreadClusterSize_M_N::At(1) *LayernormThreadSliceSize_N > LayernormBlockTileSize_M_N
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:211
static auto MakeDescriptor_X(index_t X)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:289
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op)
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:752
static constexpr auto I0
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:215
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, DsLayout, HLayout, Tuple< ADataType >, Tuple< BDataType >, AccDataType, CShuffleDataType, DsDataType, EMeanVarDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmWelford
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:226
decltype(MakeDescriptor_X< LayernormBlockTileSize_M_N::At(1)>(1)) GammaBetaGridDesc_N
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:307
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:747
static constexpr auto I3
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:218
decltype(GridwiseGemmWelford::EpilogueWelfordCShuffle::template MakeCountDescriptor_M_N< Sequence< true, true >, LayernormBlockTileSize_M_N::At(0), LayernormBlockTileSize_M_N::At(1)>(1, 1)) LayernormCountGridDesc_M_NBlock
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:301
static constexpr auto I1
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:216
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, const void *p_gamma, const void *p_beta, void *p_h, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideH, double epsilon, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, HElementwiseOperation h_element_op) override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:794
static auto MakeInvoker()
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:791
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:638
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:615
static constexpr index_t LayernormThreadSliceSize_N
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:209
GridwiseWelfordSecondHalfLayernorm2d< EMeanVarDataType, HDataType, GammaDataType, BetaDataType, AccDataType, EHGridDesc_M_N, LayernormMeanVarGridDesc_M_NBlock, LayernormCountGridDesc_M_NBlock, GammaBetaGridDesc_N, HElementwiseOperation, BlockSize, LayernormThreadClusterSize_M_N::At(I0), LayernormThreadClusterSize_M_N::At(I1), LayernormThreadSliceSize_M, LayernormThreadSliceSize_N, LayernormESrcVectorSize, LayernormHDstVectorSize, LayernormGammaSrcVectorSize, LayernormBetaSrcVectorSize > GridwiseWelfordLayernorm
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:310
DeviceGemmMultipleDLayernorm_Wmma_CShuffleV3 DeviceOp
Definition device_gemm_multiple_d_layernorm_wmma_cshuffle_v3.hpp:202
Definition device_gemm_multiple_d_layernorm.hpp:40