device_normalization_fwd_splitk_impl.hpp Source File

device_normalization_fwd_splitk_impl.hpp Source File

Composable Kernel: device_normalization_fwd_splitk_impl.hpp Source File
device_normalization_fwd_splitk_impl.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
18
19namespace ck {
20template <typename GridwiseWelford,
21 typename XDataType,
22 typename WorkspaceMeanVarDataType,
23 typename ComputeDataType,
24 typename XGridDesc_M_K,
25 typename MeanVarGridDesc_M_KBlock>
26__global__ void
27kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k,
28 const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
29 index_t num_k_block_tile_iteration,
30 const XDataType* const __restrict__ p_x_global,
31 WorkspaceMeanVarDataType* const __restrict__ p_welford_mean,
32 WorkspaceMeanVarDataType* const __restrict__ p_welford_variance,
33 int32_t* const __restrict__ p_welford_count)
34{
35 GridwiseWelford::Run(x_grid_desc_m_k,
36 mean_var_grid_desc_m_kblock,
37 num_k_block_tile_iteration,
38 p_x_global,
39 p_welford_mean,
40 p_welford_variance,
41 p_welford_count);
42};
43
44template <typename GridwiseWelfordNormalization,
45 typename WorkspaceMeanVarDataType,
46 typename XDataType,
47 typename GammaDataType,
48 typename BetaDataType,
49 typename YDataType,
50 typename SaveMeanInvStdDataType,
51 typename ComputeDataType,
52 typename YElementwiseOperation,
53 typename MeanVarGridDesc_M_KBlock,
54 typename CountGridDesc_M_KBlock,
55 typename XYGammaBetaGridDesc_M_K,
56 typename SaveMeanInvStdGridDesc_M>
57__global__ void
58kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock,
59 const CountGridDesc_M_KBlock count_grid_desc_m_kblock,
60 const XYGammaBetaGridDesc_M_K x_grid_desc_m_k,
61 const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k,
62 const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k,
63 const XYGammaBetaGridDesc_M_K y_grid_desc_m_k,
64 const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m,
65 const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m,
66 index_t num_k_mean_var_count_iteration,
67 index_t num_k_block_tile_iteration,
68 index_t k_grid_size,
69 ComputeDataType epsilon,
70 const WorkspaceMeanVarDataType* const p_mean_global,
71 const WorkspaceMeanVarDataType* const p_variance_global,
72 const int32_t* const p_welford_count_global,
73 const XDataType* const __restrict__ p_x_global,
74 const GammaDataType* const __restrict__ p_gamma_global,
75 const BetaDataType* const __restrict__ p_beta_global,
76 YDataType* const __restrict__ p_y_global,
77 SaveMeanInvStdDataType* const __restrict__ p_save_mean_global,
78 SaveMeanInvStdDataType* const __restrict__ p_save_inv_std_global,
79 const YElementwiseOperation y_elementwise_op)
80{
81 GridwiseWelfordNormalization::Run(mean_var_grid_desc_m_kblock,
82 count_grid_desc_m_kblock,
83 x_grid_desc_m_k,
84 gamma_grid_desc_m_k,
85 beta_grid_desc_m_k,
86 y_grid_desc_m_k,
87 save_mean_grid_desc_m,
88 save_inv_std_grid_desc_m,
89 num_k_mean_var_count_iteration,
90 num_k_block_tile_iteration,
91 k_grid_size,
92 epsilon,
93 p_mean_global,
94 p_variance_global,
95 p_welford_count_global,
96 p_x_global,
97 p_gamma_global,
98 p_beta_global,
99 p_y_global,
100 p_save_mean_global,
101 p_save_inv_std_global,
102 y_elementwise_op);
103};
104} // namespace ck
105
106namespace ck {
107namespace tensor_operation {
108namespace device {
109
110// Y = Normalization(X, Beta, Gamma)
111// M: Invariant length
112// K: Reduce length (Calculate mean and variance along K dimension)
113// eg. Length = [N, C, H, W], reduce dim = [C, H, W]
114// Then, M = N, K = C * H * W
115template <typename XDataType,
116 typename GammaDataType,
117 typename BetaDataType,
118 typename ComputeDataType,
119 typename YDataType,
120 typename SaveMeanInvStdDataType,
121 typename YElementwiseOperation,
122 index_t Rank,
123 index_t NumReduceDim,
124 index_t BlockSize,
125 index_t MThreadClusterSize,
126 index_t KThreadClusterSize,
127 index_t MThreadSliceSize,
128 index_t KThreadSliceSize,
129 index_t XYVectorDim,
130 index_t XSrcVectorSize,
131 index_t GammaSrcVectorDim,
132 index_t GammaSrcVectorSize,
133 index_t BetaSrcVectorDim,
134 index_t BetaSrcVectorSize,
135 index_t YDstVectorSize,
136 index_t SaveMeanInvStdDstVectorSize>
138 GammaDataType,
139 BetaDataType,
140 YDataType,
141 SaveMeanInvStdDataType,
142 YElementwiseOperation,
143 Rank,
144 NumReduceDim>
145{
146 using WorkspaceMeanVarDataType = SaveMeanInvStdDataType;
147
148 static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
149 static_assert(
150 ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
151 (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
152 "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
153
154 static_assert(
155 ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
156 (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
157 "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
158
159 static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
160 "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
161 "configuration, please check!");
162
164
165 static constexpr auto I0 = Number<0>{};
166 static constexpr auto I1 = Number<1>{};
167
168 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
169 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
170 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
171
172 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
173 static_assert(!reduceAllDim); // TODO
174
175 static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
176 const std::vector<index_t>& inStrides,
177 int kBlockSize,
178 int numBlockTileIteration)
179 {
180 static constexpr index_t numSrcDim = Rank;
181
182 const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
183 const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
184
185 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
186
187 const auto in_grid_desc_m_k = [&]() {
188 if constexpr(reduceAllDim)
189 {
190 const auto one_dim_inDesc = transform_tensor_descriptor(
191 inDesc,
192 make_tuple(make_merge_transform(tupleSrcLengths)),
195
196 return transform_tensor_descriptor(one_dim_inDesc,
198 1, one_dim_inDesc.GetLength(Number<0>{})))),
201 }
202 else
203 {
204 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
206
207 const auto reduceDimLengths =
208 make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
209 const auto invariantDimLengths =
210 make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
211
213 inDesc,
214 make_tuple(make_merge_transform(invariantDimLengths),
215 make_merge_transform(reduceDimLengths)),
216 make_tuple(InvariantDims{}, ReduceDims{}),
218 }
219 }();
220
221 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
222 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
223
224 const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
225 const auto inPad_M =
226 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
227 const auto inPad_K = reduceSizePerBlock * kBlockSize - reduceLength;
228
229 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
230 in_grid_desc_m_k,
231 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
232 make_right_pad_transform(reduceLength, inPad_K)),
235
236 return (in_grid_desc_m_k_padded);
237 };
238
239 template <typename DoPads, index_t MPerTile, index_t KPerTile>
241 {
242 const auto grid_desc_m_k =
244 return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
245 }
246
247 template <typename DoPads, index_t MPerTile, index_t KPerTile>
249 {
250 const auto grid_desc_m_k =
252 return PadTensorDescriptor(grid_desc_m_k, make_tuple(MPerTile, KPerTile), DoPads{});
253 }
254
255 static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
256 const std::vector<index_t>& strides)
257 {
258 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
259
260 const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
261 const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
262
263 const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
264
265 const auto grid_desc_m =
267 make_tuple(make_merge_transform(tupleSrcLengths)),
268 make_tuple(InvariantDims{}),
270
271 const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
272 const auto pad_M =
273 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
274
275 auto grid_desc_m_padded = transform_tensor_descriptor(
276 grid_desc_m,
277 make_tuple(make_right_pad_transform(invariantLength, pad_M)),
280
281 return grid_desc_m_padded;
282 }
283
284 using SrcGridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
287
290
293
295
297 ComputeDataType,
301 BlockSize,
302 MThreadClusterSize,
303 KThreadClusterSize,
304 MThreadSliceSize,
305 KThreadSliceSize,
306 XYVectorDim,
307 XSrcVectorSize>;
308
311 XDataType,
312 GammaDataType,
313 BetaDataType,
314 YDataType,
315 SaveMeanInvStdDataType,
316 ComputeDataType,
317 YElementwiseOperation,
322 BlockSize,
323 MThreadClusterSize,
324 KThreadClusterSize,
325 MThreadSliceSize,
326 KThreadSliceSize,
327 XYVectorDim,
328 XSrcVectorSize,
329 GammaSrcVectorDim,
330 GammaSrcVectorSize,
331 BetaSrcVectorDim,
332 BetaSrcVectorSize,
333 XYVectorDim,
334 YDstVectorSize,
335 SaveMeanInvStdDstVectorSize>;
336
337 struct Argument : public BaseArgument
338 {
339 Argument(const std::vector<index_t> lengths,
340 const std::vector<index_t> xStrides,
341 const std::vector<index_t> gammaStrides,
342 const std::vector<index_t> betaStrides,
343 const std::vector<index_t> yStrides,
344 const std::vector<index_t> saveMeanStrides,
345 const std::vector<index_t> saveInvStdStrides,
346 const std::vector<index_t> reduceDims,
347 YElementwiseOperation y_elementwise_op,
348 double epsilon,
349 const XDataType* p_x,
350 const GammaDataType* p_gamma,
351 const BetaDataType* p_beta,
352 YDataType* p_y,
353 SaveMeanInvStdDataType* p_saveMean,
354 SaveMeanInvStdDataType* p_saveInvStd)
355 : p_x_(p_x),
356 p_gamma_(p_gamma),
357 p_beta_(p_beta),
358 p_y_(p_y),
359 p_saveMean_(p_saveMean),
360 p_saveInvStd_(p_saveInvStd),
361 p_workspace_mean_{nullptr},
362 p_workspace_var_{nullptr},
363 p_workspace_count_{nullptr},
364 y_elementwise_op_(y_elementwise_op)
365 {
366 epsilon_ = static_cast<ComputeDataType>(epsilon);
367
373 saveMeanStrides_ = saveMeanStrides;
374 saveInvStdStrides_ = saveInvStdStrides;
375
377
379 while(true)
380 {
381 int testKGridSize =
383
384 // we want the kGridSize_ be not more than 128
385 if(testKGridSize <= 128)
386 break;
387
389 };
390
393
394 // We do not use vector load for mean, var and count
395 static constexpr index_t K_MeanVarCountBlockTileSize = KThreadClusterSize;
396
398 math::integer_divide_ceil(kGridSize_, K_MeanVarCountBlockTileSize);
399
408
411
412 // We don't need to pad in K dimension for Welford1. Set KPerTile 1.
416
420 K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
421
425 K_MeanVarCountBlockTileSize>(MRaw_, kGridSize_);
426
427 if constexpr(NumInvariantDim == 0)
429 else
431 }
432
433 ComputeDataType epsilon_;
434
435 const XDataType* p_x_;
436 const GammaDataType* p_gamma_;
437 const BetaDataType* p_beta_;
438 YDataType* p_y_;
439 SaveMeanInvStdDataType* p_saveMean_;
440 SaveMeanInvStdDataType* p_saveInvStd_;
444
445 std::vector<index_t> Lengths_;
446 std::vector<index_t> xStrides_;
447 std::vector<index_t> gammaStrides_;
448 std::vector<index_t> betaStrides_;
449 std::vector<index_t> yStrides_;
450 std::vector<index_t> saveMeanStrides_;
451 std::vector<index_t> saveInvStdStrides_;
452
453 YElementwiseOperation y_elementwise_op_;
454
458 size_t gridSize_;
459
466
470
471 index_t MRaw_; // Invariant length
472 index_t KRaw_; // reduce length
473
475 };
476
477 struct Invoker : public BaseInvoker
478 {
479 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
480 {
481 if(arg.p_workspace_mean_ == nullptr || arg.p_workspace_var_ == nullptr ||
482 arg.p_workspace_count_ == nullptr)
483 throw std::runtime_error("wrong! WorkSpace pointer has not been set");
484
486 XDataType,
488 ComputeDataType,
491
494 XDataType,
495 GammaDataType,
496 BetaDataType,
497 YDataType,
498 SaveMeanInvStdDataType,
499 ComputeDataType,
500 YElementwiseOperation,
505
506 float avg_time = 0;
507 avg_time += launch_and_time_kernel(
508 stream_config,
509 kernel1,
510 dim3(arg.gridSize_),
511 dim3(BlockSize),
512 0,
516 arg.p_x_,
519 static_cast<int32_t*>(arg.p_workspace_count_));
520
521 avg_time += launch_and_time_kernel(
522 stream_config,
523 kernel2,
524 dim3(arg.gridSize_),
525 dim3(BlockSize),
526 0,
537 arg.kGridSize_,
538 arg.epsilon_,
539 static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_mean_),
540 static_cast<const WorkspaceMeanVarDataType*>(arg.p_workspace_var_),
541 static_cast<const int32_t*>(arg.p_workspace_count_),
542 arg.p_x_,
543 arg.p_gamma_,
544 arg.p_beta_,
545 arg.p_y_,
546 arg.p_saveMean_,
547 arg.p_saveInvStd_,
549
550 return avg_time;
551 };
552
553 float Run(const BaseArgument* p_arg,
554 const StreamConfig& stream_config = StreamConfig{}) override
555 {
556 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
557 };
558 };
559
560 size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
561 {
562 const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
563
564 size_t workspace_size = 0;
565
566 int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
567
568 // workspace for welford intermediate mean
569 workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
570
571 // workspace for welford intermediate variance
572 workspace_size += welford_size * sizeof(WorkspaceMeanVarDataType) + 64;
573
574 // workspace for welford intermediate count
575 workspace_size += pArg_->kGridSize_ * sizeof(int32_t) + 64;
576
577 return (workspace_size);
578 };
579
581 void* p_workspace,
582 const StreamConfig& = StreamConfig{}) const override
583 {
584 Argument* pArg_ = dynamic_cast<Argument*>(pArg);
585
586 pArg_->p_workspace_ = p_workspace;
587
588 int welford_size = pArg_->MRaw_ * pArg_->kGridSize_;
589
590 // setup buffer used for intermediate welford mean
591 pArg_->p_workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);
592
593 index_t mean_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
594 mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);
595
596 // setup buffer used for intermediate welford varirance
597 pArg_->p_workspace_var_ = reinterpret_cast<char*>(pArg_->p_workspace_mean_) + mean_space_sz;
598
599 index_t variance_space_sz = welford_size * sizeof(WorkspaceMeanVarDataType);
600 variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);
601
602 // setup buffer used for intermediate welford count
603 pArg_->p_workspace_count_ =
604 reinterpret_cast<char*>(pArg_->p_workspace_var_) + variance_space_sz;
605 };
606
607 bool IsSupportedArgument(const BaseArgument* p_arg) override
608 {
609 const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
610
611 if constexpr(XYVectorDim == 0)
612 {
613 if constexpr(NumInvariantDim == 0)
614 {
615 return false;
616 }
617 else
618 {
619 if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
620 return false;
621
622 if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
623 return false;
624
625 if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
626 return false;
627 };
628 }
629 else
630 {
631 if(p_arg_->xStrides_[Rank - 1] != 1)
632 return false;
633
634 if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
635 return false;
636
637 if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
638 return false;
639 };
640
641 // if fastest dim is not reduced
642 if constexpr(GammaSrcVectorDim == 0)
643 {
644 if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
645 return false;
646
647 if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
648 return false;
649 }
650 else // if fastest dim is reduced
651 {
652 if(p_arg_->gammaStrides_[Rank - 1] != 1)
653 return false;
654
655 if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
656 return false;
657 }
658
659 // if fastest dim is not reduced
660 if constexpr(BetaSrcVectorDim == 0)
661 {
662 if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
663 return false;
664
665 if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
666 return false;
667 }
668 else // if fastest dim is reduced
669 {
670 if(p_arg_->betaStrides_[Rank - 1] != 1)
671 return false;
672
673 if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
674 return false;
675 }
676
677 if(p_arg_->kGridSize_ <= 1)
678 return false;
679
680 if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
681 return false;
682
683 return true;
684 };
685
686 std::unique_ptr<BaseArgument>
687 MakeArgumentPointer(const std::vector<index_t> lengths,
688 const std::vector<index_t> xStrides,
689 const std::vector<index_t> gammaStrides,
690 const std::vector<index_t> betaStrides,
691 const std::vector<index_t> yStrides,
692 const std::vector<index_t> saveMeanStrides,
693 const std::vector<index_t> saveInvStdStrides,
694 const std::vector<index_t> reduceDims,
695 double epsilon,
696 const void* p_x,
697 const void* p_gamma,
698 const void* p_beta,
699 void* p_y,
700 void* p_saveMean,
701 void* p_saveInvStd,
702 YElementwiseOperation y_elementwise_op) override
703 {
704 if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
705 betaStrides.size() != Rank || yStrides.size() != Rank ||
706 saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
707 throw std::runtime_error("dimension is incorrect");
708
709 return std::make_unique<Argument>(lengths,
710 xStrides,
711 gammaStrides,
712 betaStrides,
713 yStrides,
714 saveMeanStrides,
715 saveInvStdStrides,
716 reduceDims,
717 y_elementwise_op,
718 epsilon,
719 static_cast<const XDataType*>(p_x),
720 static_cast<const GammaDataType*>(p_gamma),
721 static_cast<const BetaDataType*>(p_beta),
722 static_cast<YDataType*>(p_y),
723 static_cast<SaveMeanInvStdDataType*>(p_saveMean),
724 static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
725 };
726
727 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
728 {
729 return std::make_unique<Invoker>();
730 };
731
732 std::string GetTypeString() const override
733 {
734 auto str = std::stringstream();
735
736 // clang-format off
737 str << "DeviceNormalizationFwdSplitKImpl<" << BlockSize << ",";
738 str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
739 str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
740 str << "XYSrcVectorDim_" << XYVectorDim << ",";
741 str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
742 // clang-format on
743
744 return str.str();
745 }
746};
747
748} // namespace device
749} // namespace tensor_operation
750} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__global__ void kernel_normalizationSplitK2nd(const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, const CountGridDesc_M_KBlock count_grid_desc_m_kblock, const XYGammaBetaGridDesc_M_K x_grid_desc_m_k, const XYGammaBetaGridDesc_M_K gamma_grid_desc_m_k, const XYGammaBetaGridDesc_M_K beta_grid_desc_m_k, const XYGammaBetaGridDesc_M_K y_grid_desc_m_k, const SaveMeanInvStdGridDesc_M save_mean_grid_desc_m, const SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m, index_t num_k_mean_var_count_iteration, index_t num_k_block_tile_iteration, index_t k_grid_size, ComputeDataType epsilon, const WorkspaceMeanVarDataType *const p_mean_global, const WorkspaceMeanVarDataType *const p_variance_global, const int32_t *const p_welford_count_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const BetaDataType *const __restrict__ p_beta_global, YDataType *const __restrict__ p_y_global, SaveMeanInvStdDataType *const __restrict__ p_save_mean_global, SaveMeanInvStdDataType *const __restrict__ p_save_inv_std_global, const YElementwiseOperation y_elementwise_op)
Definition device_normalization_fwd_splitk_impl.hpp:58
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_normalizationSplitK1st(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarGridDesc_M_KBlock mean_var_grid_desc_m_kblock, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x_global, WorkspaceMeanVarDataType *const __restrict__ p_welford_mean, WorkspaceMeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count)
Definition device_normalization_fwd_splitk_impl.hpp:27
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_normalization_splitk_1st.hpp:28
Definition gridwise_normalization_splitk_2nd.hpp:42
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_normalization_fwd.hpp:23
Definition device_normalization_fwd_splitk_impl.hpp:338
std::vector< index_t > Lengths_
Definition device_normalization_fwd_splitk_impl.hpp:445
SaveMeanInvStdGridDesc_M save_mean_grid_desc_m_
Definition device_normalization_fwd_splitk_impl.hpp:464
Kernel2MeanVarGridDesc_M_KBlock kernel2_mean_var_grid_desc_m_kblock_
Definition device_normalization_fwd_splitk_impl.hpp:468
std::vector< index_t > gammaStrides_
Definition device_normalization_fwd_splitk_impl.hpp:447
std::vector< index_t > betaStrides_
Definition device_normalization_fwd_splitk_impl.hpp:448
std::vector< index_t > saveInvStdStrides_
Definition device_normalization_fwd_splitk_impl.hpp:451
Kernel2CountGridDesc_M_KBlock kernel2_count_grid_desc_m_kblock_
Definition device_normalization_fwd_splitk_impl.hpp:469
void * p_workspace_mean_
Definition device_normalization_fwd_splitk_impl.hpp:441
SrcGridDesc_M_K y_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:463
int kGridSize_
Definition device_normalization_fwd_splitk_impl.hpp:455
SrcGridDesc_M_K gamma_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:461
std::vector< index_t > saveMeanStrides_
Definition device_normalization_fwd_splitk_impl.hpp:450
index_t KRaw_
Definition device_normalization_fwd_splitk_impl.hpp:472
const XDataType * p_x_
Definition device_normalization_fwd_splitk_impl.hpp:435
SaveMeanInvStdGridDesc_M save_inv_std_grid_desc_m_
Definition device_normalization_fwd_splitk_impl.hpp:465
void * p_workspace_var_
Definition device_normalization_fwd_splitk_impl.hpp:442
YElementwiseOperation y_elementwise_op_
Definition device_normalization_fwd_splitk_impl.hpp:453
size_t gridSize_
Definition device_normalization_fwd_splitk_impl.hpp:458
std::vector< index_t > xStrides_
Definition device_normalization_fwd_splitk_impl.hpp:446
int numBlockTileIteration_
Definition device_normalization_fwd_splitk_impl.hpp:457
std::vector< index_t > yStrides_
Definition device_normalization_fwd_splitk_impl.hpp:449
ComputeDataType epsilon_
Definition device_normalization_fwd_splitk_impl.hpp:433
void * p_workspace_count_
Definition device_normalization_fwd_splitk_impl.hpp:443
const BetaDataType * p_beta_
Definition device_normalization_fwd_splitk_impl.hpp:437
SrcGridDesc_M_K beta_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:462
SrcGridDesc_M_K x_grid_desc_m_k_
Definition device_normalization_fwd_splitk_impl.hpp:460
SaveMeanInvStdDataType * p_saveMean_
Definition device_normalization_fwd_splitk_impl.hpp:439
YDataType * p_y_
Definition device_normalization_fwd_splitk_impl.hpp:438
index_t invariant_lowest_length_
Definition device_normalization_fwd_splitk_impl.hpp:474
const GammaDataType * p_gamma_
Definition device_normalization_fwd_splitk_impl.hpp:436
index_t MRaw_
Definition device_normalization_fwd_splitk_impl.hpp:471
SaveMeanInvStdDataType * p_saveInvStd_
Definition device_normalization_fwd_splitk_impl.hpp:440
Kernel1MeanVarGridDesc_M_KBlock kernel1_mean_var_grid_desc_m_kblock_
Definition device_normalization_fwd_splitk_impl.hpp:467
int numMeanVarCountIteration_
Definition device_normalization_fwd_splitk_impl.hpp:456
Argument(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType *p_x, const GammaDataType *p_gamma, const BetaDataType *p_beta, YDataType *p_y, SaveMeanInvStdDataType *p_saveMean, SaveMeanInvStdDataType *p_saveInvStd)
Definition device_normalization_fwd_splitk_impl.hpp:339
Definition device_normalization_fwd_splitk_impl.hpp:478
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_normalization_fwd_splitk_impl.hpp:479
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_normalization_fwd_splitk_impl.hpp:553
Definition device_normalization_fwd_splitk_impl.hpp:145
decltype(MakeWorkspaceMeanVarDescriptor_M_K< Sequence< true, true >, 1, 1 >(1, 1)) Kernel2MeanVarGridDesc_M_KBlock
Definition device_normalization_fwd_splitk_impl.hpp:288
static constexpr auto I0
Definition device_normalization_fwd_splitk_impl.hpp:165
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
Definition device_normalization_fwd_splitk_impl.hpp:580
decltype(MakeWorkspaceMeanVarDescriptor_M_K< Sequence< true, false >, 1, 1 >(1, 1)) Kernel1MeanVarGridDesc_M_KBlock
Definition device_normalization_fwd_splitk_impl.hpp:285
decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})) SaveMeanInvStdGridDesc_M
Definition device_normalization_fwd_splitk_impl.hpp:294
static constexpr index_t K_BlockTileSize
Definition device_normalization_fwd_splitk_impl.hpp:170
static auto MakeWorkspaceCountDescriptor_M_K(index_t M, index_t K)
Definition device_normalization_fwd_splitk_impl.hpp:248
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition device_normalization_fwd_splitk_impl.hpp:255
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition device_normalization_fwd_splitk_impl.hpp:560
decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)) SrcGridDesc_M_K
Definition device_normalization_fwd_splitk_impl.hpp:284
static constexpr index_t NumInvariantDim
Definition device_normalization_fwd_splitk_impl.hpp:168
GridwiseNormalizationSplitK1st< XDataType, ComputeDataType, WorkspaceMeanVarDataType, SrcGridDesc_M_K, Kernel1MeanVarGridDesc_M_KBlock, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize > GridwiseWelford
Definition device_normalization_fwd_splitk_impl.hpp:296
SaveMeanInvStdDataType WorkspaceMeanVarDataType
Definition device_normalization_fwd_splitk_impl.hpp:146
static constexpr auto I1
Definition device_normalization_fwd_splitk_impl.hpp:166
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_normalization_fwd_splitk_impl.hpp:607
decltype(MakeWorkspaceCountDescriptor_M_K< Sequence< true, true >, 1, 1 >(1, 1)) Kernel2CountGridDesc_M_KBlock
Definition device_normalization_fwd_splitk_impl.hpp:291
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int kBlockSize, int numBlockTileIteration)
Definition device_normalization_fwd_splitk_impl.hpp:175
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_normalization_fwd_splitk_impl.hpp:727
static constexpr index_t M_BlockTileSize
Definition device_normalization_fwd_splitk_impl.hpp:169
GridwiseNormalizationSplitK2nd< WorkspaceMeanVarDataType, XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, ComputeDataType, YElementwiseOperation, Kernel2MeanVarGridDesc_M_KBlock, Kernel2CountGridDesc_M_KBlock, SrcGridDesc_M_K, SaveMeanInvStdGridDesc_M, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XYVectorDim, XSrcVectorSize, GammaSrcVectorDim, GammaSrcVectorSize, BetaSrcVectorDim, BetaSrcVectorSize, XYVectorDim, YDstVectorSize, SaveMeanInvStdDstVectorSize > GridwiseWelfordNormalization
Definition device_normalization_fwd_splitk_impl.hpp:309
static auto MakeWorkspaceMeanVarDescriptor_M_K(index_t M, index_t K)
Definition device_normalization_fwd_splitk_impl.hpp:240
tensor_operation::element_wise::PassThrough PassThrough
Definition device_normalization_fwd_splitk_impl.hpp:163
std::string GetTypeString() const override
Definition device_normalization_fwd_splitk_impl.hpp:732
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_saveMean, void *p_saveInvStd, YElementwiseOperation y_elementwise_op) override
Definition device_normalization_fwd_splitk_impl.hpp:687
static constexpr bool reduceAllDim
Definition device_normalization_fwd_splitk_impl.hpp:172
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340