device_reduce_threadwise.hpp Source File

device_reduce_threadwise.hpp Source File

Composable Kernel: device_reduce_threadwise.hpp Source File
device_reduce_threadwise.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8#include <array>
9
15
16namespace ck {
17namespace tensor_operation {
18namespace device {
19
20template <typename InDataType,
21 typename AccDataType,
22 typename OutDataType,
23 index_t Rank,
24 index_t NumReduceDim,
25 typename ReduceOperation,
26 typename InElementwiseOperation,
27 typename AccElementwiseOperation,
28 bool PropagateNan,
29 bool OutputIndex,
30 bool TransformIndexKtoGlobal,
31 bool HaveIndexInputIfOutputIndex,
32 index_t BlockSize,
33 index_t MThreadSliceSize,
34 index_t KThreadSliceSize,
35 index_t InSrcVectorDim,
36 index_t InSrcVectorSize,
37 index_t OutDstVectorSize>
38struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
39 AccDataType,
40 OutDataType,
41 Rank,
42 NumReduceDim,
43 ReduceOperation,
44 InElementwiseOperation,
45 AccElementwiseOperation,
46 PropagateNan,
47 OutputIndex>
48
49{
50 static_assert(Rank <= 12, "Bigger Rank size is not supported!");
51
52 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
53 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
54 (MThreadSliceSize % OutDstVectorSize == 0),
55 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
56
58
59 static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
60
61 static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
62
63 static constexpr index_t NumSrcDim = Rank;
64 static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
65 static constexpr bool reduceAllDim = (NumInvariantDim == 0);
66
67 static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
68 static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
69
70 static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
71 const std::array<index_t, Rank>& inStrides)
72 {
73 const auto tupleSrcLengths =
74 generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
75 const auto tupleSrcStrides =
76 generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
77
78 const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
79
80 const auto in_grid_desc_m_k = [&]() {
81 if constexpr(reduceAllDim)
82 {
83 const auto one_dim_inDesc = transform_tensor_descriptor(
84 inDesc,
85 make_tuple(make_merge_transform(tupleSrcLengths)),
88
89 return transform_tensor_descriptor(one_dim_inDesc,
91 1, one_dim_inDesc.GetLength(Number<0>{})))),
94 }
95 else
96 {
97 using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
99
100 const auto reduceDimLengths = generate_tuple(
101 [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
102 const auto invariantDimLengths =
103 generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
104
106 inDesc,
107 make_tuple(make_merge_transform(invariantDimLengths),
108 make_merge_transform(reduceDimLengths)),
109 make_tuple(InvariantDims{}, ReduceDims{}),
111 }
112 }();
113
114 const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
115 const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
116
117 const auto inPad_M =
118 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
119 const auto inPad_K =
120 math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
121
122 auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
123 in_grid_desc_m_k,
124 make_tuple(make_right_pad_transform(invariantLength, inPad_M),
125 make_right_pad_transform(reduceLength, inPad_K)),
128
129 return (in_grid_desc_m_k_padded);
130 };
131
132 static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
133 const std::array<index_t, NumDstDim>& outStrides)
134 {
135 const auto tupleDstLengths =
136 generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
137 const auto tupleDstStrides =
138 generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
139
140 auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
141
142 auto out_grid_desc_m = transform_tensor_descriptor(
143 outDesc,
144 make_tuple(make_merge_transform(tupleDstLengths)),
147
148 const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
149
150 const auto outPad =
151 math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
152
153 auto out_grid_desc_m_padded = transform_tensor_descriptor(
154 out_grid_desc_m,
155 make_tuple(make_right_pad_transform(invariantLength, outPad)),
158 return (out_grid_desc_m_padded);
159 };
160
161 struct Argument : public BaseArgument
162 {
163 Argument(const std::array<index_t, Rank> inLengths,
164 const std::array<index_t, Rank> inStrides,
165 const std::array<index_t, NumDstDim> outLengths,
166 const std::array<index_t, NumDstDim> outStrides,
167 const std::array<int, NumReduceDim> reduceDims,
168 double alpha,
169 double beta,
170 const InDataType* in_dev,
171 OutDataType* out_dev,
172 IndexDataType* out_index_dev,
173 const InElementwiseOperation in_elementwise_op,
174 const AccElementwiseOperation acc_elementwise_op)
175 : outLengths_{outLengths},
176 outStrides_{outStrides},
177 in_dev_{in_dev},
178 out_dev_{out_dev},
179 out_index_dev_{out_index_dev},
180 in_elementwise_op_{in_elementwise_op},
181 acc_elementwise_op_{acc_elementwise_op}
182 {
185
188
191
192 if constexpr(NumInvariantDim == 0)
194 else
196
198
200
203 }
204
205 std::array<index_t, Rank> inLengths_;
206 std::array<index_t, Rank> inStrides_;
207 std::array<index_t, NumDstDim> outLengths_;
208 std::array<index_t, NumDstDim> outStrides_;
209
210 AccDataType alpha_;
211 AccDataType beta_;
212
213 const InDataType* in_dev_;
214 OutDataType* out_dev_;
216
217 InElementwiseOperation in_elementwise_op_;
218 AccElementwiseOperation acc_elementwise_op_;
219
224
226 size_t gridSize;
227 };
228
229 struct Invoker : public BaseInvoker
230 {
231 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
232 {
233 const auto in_grid_desc_m_k =
235 const auto out_grid_desc_m =
237 using InGridDesc_M_K = decltype(in_grid_desc_m_k);
238 using OutGridDesc_M = decltype(out_grid_desc_m);
239
240 float avg_time = 0;
241
242 using GridwiseReduce =
244 OutDataType,
245 AccDataType,
247 InGridDesc_M_K,
248 OutGridDesc_M,
249 ReduceOperation,
250 InElementwiseOperation,
251 AccElementwiseOperation,
253 PropagateNan,
254 BlockSize,
255 MThreadSliceSize,
256 KThreadSliceSize,
257 InSrcVectorDim,
258 InSrcVectorSize,
259 OutDstVectorSize>;
260
261 const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
262 OutputIndex,
263 TransformIndexKtoGlobal,
265 InDataType,
266 OutDataType,
267 AccDataType,
269 InGridDesc_M_K,
270 OutGridDesc_M,
271 InElementwiseOperation,
272 AccElementwiseOperation>;
273
274 avg_time = launch_and_time_kernel(stream_config,
275 kernel,
276 dim3(arg.gridSize),
277 dim3(BlockSize),
278 0,
279 in_grid_desc_m_k,
280 out_grid_desc_m,
283 arg.alpha_,
284 arg.in_dev_,
285 nullptr,
286 arg.beta_,
287 arg.out_dev_,
288 arg.out_index_dev_);
289
290 return (avg_time);
291 };
292
293 float Run(const BaseArgument* p_arg,
294 const StreamConfig& stream_config = StreamConfig{}) override
295 {
296 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
297 };
298 };
299
300 bool IsSupportedArgument(const BaseArgument* p_arg) override
301 {
302 const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
303
304 if constexpr(InSrcVectorDim == 0)
305 {
306 if constexpr(NumInvariantDim == 0)
307 {
308 return (false);
309 }
310 else
311 {
312 if(pArg->inStrides_[NumInvariantDim - 1] != 1)
313 return (false);
314
315 if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
316 return (false);
317 };
318 }
319 else
320 {
321 if(pArg->inStrides_[Rank - 1] != 1)
322 return (false);
323
324 if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
325 return (false);
326 };
327
328 // To improve
329 if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
330 return (false);
331
332 // cases with big reduce_total_length should be handled by Blockwise kernel
333 if(pArg->reduce_total_length / KThreadSliceSize >= 32)
334 return (false);
335
336 return (true);
337 };
338
339 std::unique_ptr<BaseArgument>
340 MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
341 const std::array<index_t, Rank> inStrides,
342 const std::array<index_t, NumDstDim> outLengths,
343 const std::array<index_t, NumDstDim> outStrides,
344 const std::array<int, NumReduceDim> reduceDims,
345 double alpha,
346 double beta,
347 const void* in_dev,
348 const void* in_index_dev,
349 void* out_dev,
350 void* out_index_dev,
351 const InElementwiseOperation in_elementwise_op,
352 const AccElementwiseOperation acc_elementwise_op) override
353 {
354 (void)in_index_dev;
355
356 return std::make_unique<Argument>(inLengths,
357 inStrides,
358 outLengths,
359 outStrides,
360 reduceDims,
361 alpha,
362 beta,
363 static_cast<const InDataType*>(in_dev),
364 static_cast<OutDataType*>(out_dev),
365 static_cast<IndexDataType*>(out_index_dev),
366 in_elementwise_op,
367 acc_elementwise_op);
368 };
369
370 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
371 {
372 return std::make_unique<Invoker>();
373 };
374
375 std::string GetTypeString() const override
376 {
377 auto str = std::stringstream();
378
379 // clang-format off
380 str << "DeviceReduceThreadWise<" << BlockSize << ",";
381 str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
382 str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
383 str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
384 // clang-format on
385
386 return str.str();
387 }
388};
389
390} // namespace device
391} // namespace tensor_operation
392} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr auto integer_least_multiple(X x, Y y)
Definition utility/math.hpp:78
Definition convolution_backward_data_specialization.hpp:8
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition gridwise_2d_reduction_threadwise.hpp:28
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
signed int int32_t
Definition stdint.h:123
Definition ck/stream_config.hpp:10
Definition gridwise_2d_reduction_threadwise.hpp:84
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271
Definition device_base.hpp:197
Definition device_reduce.hpp:27
Definition device_reduce_threadwise.hpp:162
std::array< index_t, Rank > inLengths_
Definition device_reduce_threadwise.hpp:205
std::array< index_t, NumDstDim > outStrides_
Definition device_reduce_threadwise.hpp:208
InElementwiseOperation in_elementwise_op_
Definition device_reduce_threadwise.hpp:217
int numBlockTileIteration
Definition device_reduce_threadwise.hpp:225
IndexDataType * out_index_dev_
Definition device_reduce_threadwise.hpp:215
index_t invariant_lowest_length
Definition device_reduce_threadwise.hpp:220
OutDataType * out_dev_
Definition device_reduce_threadwise.hpp:214
AccDataType alpha_
Definition device_reduce_threadwise.hpp:210
size_t gridSize
Definition device_reduce_threadwise.hpp:226
AccDataType beta_
Definition device_reduce_threadwise.hpp:211
AccElementwiseOperation acc_elementwise_op_
Definition device_reduce_threadwise.hpp:218
std::array< index_t, Rank > inStrides_
Definition device_reduce_threadwise.hpp:206
std::array< index_t, NumDstDim > outLengths_
Definition device_reduce_threadwise.hpp:207
index_t reduce_lowest_length
Definition device_reduce_threadwise.hpp:221
long_index_t invariant_total_length
Definition device_reduce_threadwise.hpp:222
long_index_t reduce_total_length
Definition device_reduce_threadwise.hpp:223
Argument(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, double alpha, double beta, const InDataType *in_dev, OutDataType *out_dev, IndexDataType *out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op)
Definition device_reduce_threadwise.hpp:163
const InDataType * in_dev_
Definition device_reduce_threadwise.hpp:213
Definition device_reduce_threadwise.hpp:230
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_reduce_threadwise.hpp:293
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_reduce_threadwise.hpp:231
Definition device_reduce_threadwise.hpp:49
static constexpr index_t NumSrcDim
Definition device_reduce_threadwise.hpp:63
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, double alpha, double beta, const void *in_dev, const void *in_index_dev, void *out_dev, void *out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op) override
Definition device_reduce_threadwise.hpp:340
int32_t IndexDataType
Definition device_reduce_threadwise.hpp:57
static constexpr index_t NumDstDim
Definition device_reduce_threadwise.hpp:64
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_reduce_threadwise.hpp:300
std::string GetTypeString() const override
Definition device_reduce_threadwise.hpp:375
static constexpr index_t NumInvariantDim
Definition device_reduce_threadwise.hpp:61
static constexpr index_t M_BlockTileSize
Definition device_reduce_threadwise.hpp:67
static constexpr bool reduceAllDim
Definition device_reduce_threadwise.hpp:65
static constexpr bool HaveIndexInput
Definition device_reduce_threadwise.hpp:59
static constexpr index_t K_BlockTileSize
Definition device_reduce_threadwise.hpp:68
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_reduce_threadwise.hpp:370
static auto MakeSrc2dDescriptor(const std::array< index_t, Rank > &inLengths, const std::array< index_t, Rank > &inStrides)
Definition device_reduce_threadwise.hpp:70
static auto MakeDst1dDescriptor(const std::array< index_t, NumDstDim > &outLengths, const std::array< index_t, NumDstDim > &outStrides)
Definition device_reduce_threadwise.hpp:132