matrix_padder.hpp Source File

matrix_padder.hpp Source File#

Composable Kernel: matrix_padder.hpp Source File
matrix_padder.hpp
Go to the documentation of this file.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15template <typename TensorDesc,
16 typename TileLengths, // Tuple<...>
17 typename DoPads> // Sequence<bool, bool, ...>
18__host__ __device__ constexpr auto
19PadTensorDescriptor(const TensorDesc& desc, const TileLengths& tile_lengths, DoPads)
20{
21 constexpr index_t num_dim = DoPads::Size();
22
23 static_assert(num_dim == TileLengths::Size() && num_dim == TensorDesc::GetNumOfDimension(),
24 "wrong! inconsistent # of dimensions");
25
26 // transforms
27 const auto transforms = generate_tuple(
28 [&](auto idim) {
29 const auto MRaw = desc.GetLength(idim);
30
31 const auto MPerTile = tile_lengths[idim];
32
33 const auto M = math::integer_divide_ceil(MRaw, MPerTile) * MPerTile;
34
35 const auto MPad = M - MRaw;
36
37 const bool DoPadM = DoPads::At(idim);
38
39 const auto MTransform = conditional_expr<DoPadM>(make_right_pad_transform(MRaw, MPad),
41
42 return MTransform;
43 },
45
46 // lower dimension Id
47 const auto lower_dimss =
48 generate_tuple([&](auto idim) { return Sequence<idim.value>{}; }, Number<num_dim>{});
49
50 // upper dimension Id
51 const auto upper_dimss = lower_dimss;
52
53 return transform_tensor_descriptor(desc, transforms, lower_dimss, upper_dimss);
54}
55
56// M/N/K/OPerTileType could be index_t or Number<>
57template <GemmSpecialization GemmSpec,
58 typename MPerTileType,
59 typename NPerTileType,
60 typename KPerTileType,
61 typename OPerTileType>
63{
64 // TODO: hard to scale; use mask instead
65 static constexpr bool PadM =
70 static constexpr bool PadN =
75 static constexpr bool PadK =
80 static constexpr bool PadO =
85
86 // A[M, K]
87 template <typename ADesc_MRaw_KRaw>
88 __host__ __device__ constexpr auto
89 PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
90 {
92 a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
93 }
94
95 // B[K, N]
96 template <typename BDesc_NRaw_KRaw>
97 __host__ __device__ constexpr auto
98 PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
99 {
100 return PadTensorDescriptor(
101 b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
102 }
103
104 // B1[Gemm1N, Gemm1K] = B1[O, N]
105 template <typename B1Desc_NRaw_KRaw>
106 __host__ __device__ constexpr auto
107 PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
108 {
109 return PadTensorDescriptor(
110 b1_desc_nraw_kraw, make_tuple(OPerTile_, NPerTile_), Sequence<PadO, PadN>{});
111 }
112
113 // C[M, Gemm1N] = C[M, O]
114 template <typename CDesc_MRaw_NRaw>
115 __host__ __device__ constexpr auto
116 PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
117 {
118 return PadTensorDescriptor(
119 c_desc_mraw_nraw, make_tuple(MPerTile_, OPerTile_), Sequence<PadM, PadO>{});
120 }
121
122 MPerTileType MPerTile_;
123 NPerTileType NPerTile_;
124 KPerTileType KPerTile_;
125 OPerTileType OPerTile_;
126};
127
128// M/N/KPerTileType could be index_t or Number<>
129template <GemmSpecialization GemmSpec,
130 typename MPerTileType,
131 typename NPerTileType,
132 typename KPerTileType>
134{
135 static constexpr bool PadM =
138 static constexpr bool PadN =
141 static constexpr bool PadK =
144
145 template <typename ADesc_MRaw_KRaw>
146 __host__ __device__ constexpr auto
147 PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
148 {
149 return PadTensorDescriptor(
150 a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
151 }
152
153 template <typename BDesc_NRaw_KRaw>
154 __host__ __device__ constexpr auto
155 PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
156 {
157 return PadTensorDescriptor(
158 b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
159 }
160
161 template <typename CDesc_MRaw_NRaw>
162 __host__ __device__ constexpr auto
163 PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
164 {
165 return PadTensorDescriptor(
166 c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
167 }
168
169 MPerTileType MPerTile_;
170 NPerTileType NPerTile_;
171 KPerTileType KPerTile_;
172};
173
174// Alias of GemmPadder; to deprecate
175template <GemmSpecialization GemmSpec,
176 typename MPerTileType,
177 typename NPerTileType,
178 typename KPerTileType>
179struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
180{
181};
182
183// function to take in a struct of type MatrixPadder and call the appropriate function to get
184// the output descriptor at runtime for codegen
185template <GemmSpecialization GemmSpec,
186 typename MPerTileType,
187 typename NPerTileType,
188 typename KPerTileType,
189 typename CDesc_MRaw_NRaw>
191 CDesc_MRaw_NRaw conv_desc)
192{
193 auto res = matrix_padder.PadCDescriptor_M_N(conv_desc);
194 return res;
195}
196// M/N/KPerTileType could be index_t or Number<>
197template <bool PadM,
198 bool PadN,
199 bool PadK,
200 typename MPerTileType,
201 typename NPerTileType,
202 typename KPerTileType>
204{
205 template <typename ADesc_MRaw_KRaw>
206 __host__ __device__ constexpr auto
207 PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
208 {
209 return PadTensorDescriptor(
210 a_desc_mraw_kraw, make_tuple(MPerTile_, KPerTile_), Sequence<PadM, PadK>{});
211 }
212
213 template <typename BDesc_NRaw_KRaw>
214 __host__ __device__ constexpr auto
215 PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
216 {
217 return PadTensorDescriptor(
218 b_desc_nraw_kraw, make_tuple(NPerTile_, KPerTile_), Sequence<PadN, PadK>{});
219 }
220
221 template <typename CDesc_MRaw_NRaw>
222 __host__ __device__ constexpr auto
223 PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
224 {
225 return PadTensorDescriptor(
226 c_desc_mraw_nraw, make_tuple(MPerTile_, NPerTile_), Sequence<PadM, PadN>{});
227 }
228
229 MPerTileType MPerTile_;
230 NPerTileType NPerTile_;
231 KPerTileType KPerTile_;
232};
233
234// M/N/KPerTileType could be index_t or Number<>
235template <bool PadM,
236 bool PadN,
237 bool PadK,
238 typename MPerTileType,
239 typename NPerTileType,
240 typename KPerTileType>
242{
243 static constexpr auto I0 = Number<0>{};
244 static constexpr auto I1 = Number<1>{};
245 static constexpr auto I2 = Number<2>{};
246 static constexpr auto I3 = Number<3>{};
247
248 template <typename ADesc_MRaw_KRaw>
249 __host__ __device__ constexpr auto
250 PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
251 {
252 const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
253 const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
254
255 const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
256 const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
257
258 const auto MPad = M - MRaw;
259 const auto KPad = K - KRaw;
260
261 if constexpr(PadM && PadK)
262 {
263 // pad both M and K
264 return transform_tensor_descriptor(a_desc_mraw_kraw,
266 make_right_pad_transform(KRaw, KPad)),
269 }
270 else if constexpr(PadM && (!PadK))
271 {
272 // pad M, but not K
274 a_desc_mraw_kraw,
278 }
279 else if constexpr((!PadM) && PadK)
280 {
281 // pad K, but not M
283 a_desc_mraw_kraw,
287 }
288 else
289 {
290 // not pad M or K
291 return a_desc_mraw_kraw;
292 }
293 }
294
295 template <typename BDesc_NRaw_KRaw>
296 __host__ __device__ constexpr auto
297 PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
298 {
299 const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
300 const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
301
302 const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
303 const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
304
305 const auto NPad = N - NRaw;
306 const auto KPad = K - KRaw;
307
308 if constexpr(PadN && PadK)
309 {
310 // pad both N and K
311 return transform_tensor_descriptor(b_desc_nraw_kraw,
313 make_right_pad_transform(KRaw, KPad)),
316 }
317 else if constexpr(PadN && (!PadK))
318 {
319 // pad N, but not K
321 b_desc_nraw_kraw,
325 }
326 else if constexpr((!PadN) && PadK)
327 {
328 // pad K, but not N
330 b_desc_nraw_kraw,
334 }
335 else
336 {
337 // not pad N or K
338 return b_desc_nraw_kraw;
339 }
340 }
341
342 template <typename CDesc_MRaw_NRaw>
343 __host__ __device__ constexpr auto
344 PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
345 {
346 const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
347 const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
348
349 const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
350 const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
351
352 const auto MPad = M - MRaw;
353 const auto NPad = N - NRaw;
354
355 if constexpr(PadM && PadN)
356 {
357 // pad M and N
358 return transform_tensor_descriptor(c_desc_mraw_nraw,
360 make_right_pad_transform(NRaw, NPad)),
363 }
364 else if constexpr(PadM && (!PadN))
365 {
366 // pad M, but not N
368 c_desc_mraw_nraw,
372 }
373 else if constexpr((!PadM) && PadN)
374 {
375 // pad N, but not M
377 c_desc_mraw_nraw,
381 }
382 else
383 {
384 // not pad M or N
385 return c_desc_mraw_nraw;
386 }
387 }
388
389 MPerTileType MPerTile_;
390 NPerTileType NPerTile_;
391 KPerTileType KPerTile_;
392};
393} // namespace device
394} // namespace tensor_operation
395} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
Definition convolution_backward_data_specialization.hpp:8
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition matrix_padder.hpp:190
__host__ __device__ constexpr auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition matrix_padder.hpp:19
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ NKOPadding
Definition gemm_specialization.hpp:28
@ MNOPadding
Definition gemm_specialization.hpp:26
@ KOPadding
Definition gemm_specialization.hpp:25
@ KPadding
Definition gemm_specialization.hpp:16
@ MOPadding
Definition gemm_specialization.hpp:23
@ OPadding
Definition gemm_specialization.hpp:22
@ NOPadding
Definition gemm_specialization.hpp:24
@ NPadding
Definition gemm_specialization.hpp:15
@ MPadding
Definition gemm_specialization.hpp:14
@ MNKOPadding
Definition gemm_specialization.hpp:29
@ MNKPadding
Definition gemm_specialization.hpp:20
@ MNPadding
Definition gemm_specialization.hpp:17
@ MKOPadding
Definition gemm_specialization.hpp:27
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
constexpr auto conditional_expr(X &&x, Y &&y)
Definition utility/functional.hpp:119
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
Definition utility/sequence.hpp:43
Definition matrix_padder.hpp:63
KPerTileType KPerTile_
Definition matrix_padder.hpp:124
OPerTileType OPerTile_
Definition matrix_padder.hpp:125
static constexpr bool PadM
Definition matrix_padder.hpp:65
MPerTileType MPerTile_
Definition matrix_padder.hpp:122
static constexpr bool PadN
Definition matrix_padder.hpp:70
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:89
NPerTileType NPerTile_
Definition matrix_padder.hpp:123
static constexpr bool PadO
Definition matrix_padder.hpp:80
__host__ __device__ constexpr auto PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw &b1_desc_nraw_kraw) const
Definition matrix_padder.hpp:107
static constexpr bool PadK
Definition matrix_padder.hpp:75
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:98
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:116
Definition matrix_padder.hpp:204
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:223
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:207
MPerTileType MPerTile_
Definition matrix_padder.hpp:229
NPerTileType NPerTile_
Definition matrix_padder.hpp:230
KPerTileType KPerTile_
Definition matrix_padder.hpp:231
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:215
Definition matrix_padder.hpp:134
NPerTileType NPerTile_
Definition matrix_padder.hpp:170
MPerTileType MPerTile_
Definition matrix_padder.hpp:169
static constexpr bool PadK
Definition matrix_padder.hpp:141
KPerTileType KPerTile_
Definition matrix_padder.hpp:171
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:155
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:163
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:147
static constexpr bool PadM
Definition matrix_padder.hpp:135
static constexpr bool PadN
Definition matrix_padder.hpp:138
Definition matrix_padder.hpp:242
KPerTileType KPerTile_
Definition matrix_padder.hpp:391
static constexpr auto I2
Definition matrix_padder.hpp:245
MPerTileType MPerTile_
Definition matrix_padder.hpp:389
__host__ __device__ constexpr auto PadCDescriptor_M_N(const CDesc_MRaw_NRaw &c_desc_mraw_nraw) const
Definition matrix_padder.hpp:344
__host__ __device__ constexpr auto PadBDescriptor_N_K(const BDesc_NRaw_KRaw &b_desc_nraw_kraw) const
Definition matrix_padder.hpp:297
static constexpr auto I3
Definition matrix_padder.hpp:246
static constexpr auto I0
Definition matrix_padder.hpp:243
__host__ __device__ constexpr auto PadADescriptor_M_K(const ADesc_MRaw_KRaw &a_desc_mraw_kraw) const
Definition matrix_padder.hpp:250
NPerTileType NPerTile_
Definition matrix_padder.hpp:390
static constexpr auto I1
Definition matrix_padder.hpp:244
Definition matrix_padder.hpp:180