device_reduce_common.hpp Source File

device_reduce_common.hpp Source File#

Composable Kernel: device_reduce_common.hpp Source File
device_reduce_common.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <vector>
7#include <cassert>
8
12
13namespace ck {
14namespace tensor_operation {
15namespace device {
16
17// here, inLengths[] is already shuffled so that lengths of invariant dims are included before those
18// of reduce dims
19template <index_t Rank, int NumReduceDim>
20std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
21{
22 static_assert(Rank <= 12, "bigger Rank size not supported!");
23
24 long_index_t invariant_total_length = 1;
25 long_index_t reduce_total_length = 1;
26
27 constexpr int NumInvariantDim = Rank - NumReduceDim;
28
29 for(int i = NumInvariantDim; i < Rank; i++)
30 reduce_total_length *= inLengths[i];
31
32 for(int i = 0; i < NumInvariantDim; i++)
33 invariant_total_length *= inLengths[i];
34
35 return std::make_pair(invariant_total_length, reduce_total_length);
36};
37
38template <index_t Rank, int NumReduceDim>
39std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
40{
41 static_assert(Rank <= 12, "bigger Rank size not supported!");
42
43 long_index_t invariant_total_length = 1;
44 long_index_t reduce_total_length = 1;
45
46 constexpr int NumInvariantDim = Rank - NumReduceDim;
47
48 for(int i = NumInvariantDim; i < Rank; i++)
49 reduce_total_length *= inLengths[i];
50
51 for(int i = 0; i < NumInvariantDim; i++)
52 invariant_total_length *= inLengths[i];
53
54 return std::make_pair(invariant_total_length, reduce_total_length);
55};
56
57// helper functions using variadic template arguments
58template <index_t... Ns>
59auto make_tuple_from_array_and_index_seq(const std::vector<index_t>& lengths, Sequence<Ns...>)
60{
61 return make_tuple(static_cast<index_t>(lengths[Ns])...);
62};
63
64template <index_t arraySize>
65auto make_tuple_from_array(const std::vector<index_t>& lengths, Number<arraySize>)
66{
67 static_assert(arraySize >= 1 && arraySize <= 6, "The tensor should have 1 to 6 dimensions");
68
69 constexpr auto index_seq = typename arithmetic_sequence_gen<0, arraySize, 1>::type{};
70
71 return make_tuple_from_array_and_index_seq(lengths, index_seq);
72};
73
74template <index_t Rank, index_t NumReduceDim>
75std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
76 const std::vector<int>& reduceDims)
77{
78 std::vector<index_t> newLengthsStrides;
79
80 assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
81
82 int reduceFlag = 0;
83
84 // flag the bits for the reduceDims
85 for(int i = 0; i < NumReduceDim; i++)
86 {
87 reduceFlag |= 1 << reduceDims[i];
88 };
89
90 // collect invariant dimensions
91 for(int i = 0; i < Rank; i++)
92 if((reduceFlag & (1 << i)) == 0)
93 {
94 newLengthsStrides.push_back(origLengthsStrides[i]);
95 };
96
97 // collect reduce dimensions
98 for(int i = 0; i < Rank; i++)
99 if((reduceFlag & (1 << i)) > 0)
100 {
101 newLengthsStrides.push_back(origLengthsStrides[i]);
102 };
103
104 return newLengthsStrides;
105};
106
107template <index_t Rank, index_t NumReduceDim>
108std::array<index_t, Rank>
109shuffle_tensor_dimensions(const std::array<index_t, Rank>& origLengthsStrides,
110 const std::array<int, NumReduceDim>& reduceDims)
111{
112 std::array<index_t, Rank> newLengthsStrides;
113
114 int reduceFlag = 0;
115
116 // flag the bits for the reduceDims
117 for(int i = 0; i < NumReduceDim; i++)
118 {
119 reduceFlag |= 1 << reduceDims[i];
120 };
121
122 // collect invariant dimensions
123 int pos = 0;
124 for(int i = 0; i < Rank; i++)
125 if((reduceFlag & (1 << i)) == 0)
126 {
127 newLengthsStrides[pos++] = origLengthsStrides[i];
128 };
129
130 // collect reduce dimensions
131 for(int i = 0; i < Rank; i++)
132 if((reduceFlag & (1 << i)) > 0)
133 {
134 newLengthsStrides[pos++] = origLengthsStrides[i];
135 };
136
137 return newLengthsStrides;
138};
139
140} // namespace device
141} // namespace tensor_operation
142} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition device_reduce_common.hpp:65
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition device_reduce_common.hpp:20
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition device_reduce_common.hpp:75
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition device_reduce_common.hpp:59
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
int64_t long_index_t
Definition ck.hpp:300
Definition utility/sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition utility/sequence.hpp:271