transform_conv_bwd_weight_to_gemm_v2.hpp Source File

transform_conv_bwd_weight_to_gemm_v2.hpp Source File#

Composable Kernel: transform_conv_bwd_weight_to_gemm_v2.hpp Source File
transform_conv_bwd_weight_to_gemm_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14namespace tensor_operation {
15
25template <index_t NDimSpatial,
26 index_t MPerBlock,
27 index_t NPerBlock,
28 index_t GemmK1Number,
29 index_t K0PerBlock,
30 index_t NumGroupsToMerge,
31 device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
33{
34 static constexpr auto I0 = Number<0>{};
35 static constexpr auto I1 = Number<1>{};
36
37 template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
38 constexpr static auto
40 const index_t Wo,
41 const index_t K,
42 const std::array<index_t, NDimSpatial + 3>& output_strides)
43 {
44 const index_t BatchStride = output_strides[0];
45 const index_t WoStride = output_strides[3];
46 const auto KStride = Number<1>{};
47 return make_naive_tensor_descriptor(make_tuple(N * Wo, NumGroupsToMerge, K),
48 make_tuple(WoStride, BatchStride, KStride));
49 }
50
51 template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
52 constexpr static auto
54 const index_t Wi,
55 const index_t C,
56 const std::array<index_t, NDimSpatial + 3>& input_strides)
57 {
58 const index_t BatchStride = input_strides[0];
59 const index_t NStride = input_strides[1];
60 const index_t WiStride = input_strides[3];
61 const auto CStride = input_strides[2];
62 if constexpr(ConvBackwardWeightSpecialization ==
64 {
65 return make_naive_tensor_descriptor(make_tuple(N * Wi, NumGroupsToMerge, C),
66 make_tuple(WiStride, BatchStride, CStride));
67 }
68 else
69 {
71 make_tuple(N, Wi, NumGroupsToMerge, C),
72 make_tuple(NStride, WiStride, BatchStride, CStride));
73 }
74 }
75
76 template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
77 constexpr static auto
79 const index_t X,
80 const index_t C,
81 const std::array<index_t, NDimSpatial + 3>& weights_strides)
82 {
83 const auto CStride = Number<1>{};
84 const auto KStride = weights_strides[1];
85 const auto XStride = weights_strides[3];
86 const auto BatchStride = weights_strides[0];
87 // Add NumGroupsToMerge for the Batch+M dimension, and 1 as a placeholder
88 // for the Batch+N dimension
89 const auto desc = make_naive_tensor_descriptor(
90 make_tuple(NumGroupsToMerge, K, X, 1, C),
91 make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
92 // Pad the placeholder dimension from 1 up to NumGroupsToMerge
93 const auto padded_desc = transform_tensor_descriptor(
94 desc,
98 make_pad_transform(1, 0, NumGroupsToMerge - 1),
102 // We need only the matrices on the diagonal. Xor returns 0 for equal
103 // values, so a matrix that is not on the diagonal is stored in the padding.
104 // To avoid a modulo after the xor, NumGroupsToMerge is assumed to be a power of 2.
105 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
106 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
107 NumGroupsToMerge == 64);
108 const auto unmerged_padded_desc = transform_tensor_descriptor(
109 padded_desc,
110 make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
116 // Merge To M, N
118 unmerged_padded_desc,
119 make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
120 make_merge_transform(make_tuple(X, NumGroupsToMerge, C))),
123 }
124
125 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
126 constexpr static auto
128 const index_t Ho,
129 const index_t Wo,
130 const index_t K,
131 const std::array<index_t, NDimSpatial + 3>& output_strides)
132 {
133 const index_t BatchStride = output_strides[0];
134 const index_t WoStride = output_strides[4];
135 const auto KStride = Number<1>{};
136 return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, NumGroupsToMerge, K),
137 make_tuple(WoStride, BatchStride, KStride));
138 }
139
140 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
141 constexpr static auto
143 const index_t Hi,
144 const index_t Wi,
145 const index_t C,
146 const std::array<index_t, NDimSpatial + 3>& input_strides)
147 {
148 const index_t BatchStride = input_strides[0];
149 const index_t NStride = input_strides[1];
150 const index_t HiStride = input_strides[3];
151 const index_t WiStride = input_strides[4];
152 const auto CStride = input_strides[2];
153 if constexpr(ConvBackwardWeightSpecialization ==
155 {
156 return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, NumGroupsToMerge, C),
157 make_tuple(WiStride, BatchStride, CStride));
158 }
159 else
160 {
162 make_tuple(N, Hi, Wi, NumGroupsToMerge, C),
163 make_tuple(NStride, HiStride, WiStride, BatchStride, CStride));
164 }
165 }
166
167 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
168 constexpr static auto
170 const index_t Y,
171 const index_t X,
172 const index_t C,
173 const std::array<index_t, NDimSpatial + 3>& weights_strides)
174 {
175 const auto CStride = Number<1>{};
176 const auto KStride = weights_strides[1];
177 const auto XStride = weights_strides[4];
178 const auto BatchStride = weights_strides[0];
179 // Add NumGroupsToMerge for the Batch+M dimension, and 1 as a placeholder
180 // for the Batch+N dimension
181 const auto desc = make_naive_tensor_descriptor(
182 make_tuple(NumGroupsToMerge, K, Y * X, 1, C),
183 make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
184 // Pad the placeholder dimension from 1 up to NumGroupsToMerge
185 const auto padded_desc = transform_tensor_descriptor(
186 desc,
187 make_tuple(make_pass_through_transform(NumGroupsToMerge),
190 make_pad_transform(1, 0, NumGroupsToMerge - 1),
194 // We need only the matrices on the diagonal. Xor returns 0 for equal
195 // values, so a matrix that is not on the diagonal is stored in the padding.
196 // To avoid a modulo after the xor, NumGroupsToMerge is assumed to be a power of 2.
197 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
198 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
199 NumGroupsToMerge == 64);
200 const auto unmerged_padded_desc = transform_tensor_descriptor(
201 padded_desc,
202 make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
208 // Merge To M, N
210 unmerged_padded_desc,
211 make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
212 make_merge_transform(make_tuple(Y * X, NumGroupsToMerge, C))),
215 }
216
217 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
218 constexpr static auto
220 const index_t Do,
221 const index_t Ho,
222 const index_t Wo,
223 const index_t K,
224 const std::array<index_t, NDimSpatial + 3>& output_strides)
225 {
226 const index_t BatchStride = output_strides[0];
227 const index_t WoStride = output_strides[5];
228 const auto KStride = Number<1>{};
229 return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, NumGroupsToMerge, K),
230 make_tuple(WoStride, BatchStride, KStride));
231 }
232
233 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
234 constexpr static auto
236 const index_t Di,
237 const index_t Hi,
238 const index_t Wi,
239 const index_t C,
240 const std::array<index_t, NDimSpatial + 3>& input_strides)
241 {
242 const index_t BatchStride = input_strides[0];
243 const index_t NStride = input_strides[1];
244 const index_t DiStride = input_strides[3];
245 const index_t HiStride = input_strides[4];
246 const index_t WiStride = input_strides[5];
247 const auto CStride = input_strides[2];
248 if constexpr(ConvBackwardWeightSpecialization ==
250 {
251 return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, NumGroupsToMerge, C),
252 make_tuple(WiStride, BatchStride, CStride));
253 }
254 else
255 {
257 make_tuple(N, Di, Hi, Wi, NumGroupsToMerge, C),
258 make_tuple(NStride, DiStride, HiStride, WiStride, BatchStride, CStride));
259 }
260 }
261
262 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
263 constexpr static auto
265 const index_t Z,
266 const index_t Y,
267 const index_t X,
268 const index_t C,
269 const std::array<index_t, NDimSpatial + 3>& weights_strides)
270 {
271 const auto CStride = Number<1>{};
272 const auto KStride = weights_strides[1];
273 const auto XStride = weights_strides[5];
274 const auto BatchStride = weights_strides[0];
275 // Add NumGroupsToMerge for the Batch+M dimension, and 1 as a placeholder for the Batch+N dimension
276 const auto desc = make_naive_tensor_descriptor(
277 make_tuple(NumGroupsToMerge, K, Z * Y * X, 1, C),
278 make_tuple(BatchStride, KStride, XStride, BatchStride, CStride));
279 // Pad the placeholder dimension from 1 up to NumGroupsToMerge
280 const auto padded_desc = transform_tensor_descriptor(
281 desc,
282 make_tuple(make_pass_through_transform(NumGroupsToMerge),
285 make_pad_transform(1, 0, NumGroupsToMerge - 1),
289 // We need only the matrices on the diagonal. Xor returns 0 for equal
290 // values, so a matrix that is not on the diagonal is stored in the padding.
291 // To avoid a modulo after the xor, NumGroupsToMerge is assumed to be a power of 2.
292 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
293 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 || NumGroupsToMerge == 32 ||
294 NumGroupsToMerge == 64);
295 const auto unmerged_padded_desc = transform_tensor_descriptor(
296 padded_desc,
297 make_tuple(make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
303 // Merge To M, N
305 unmerged_padded_desc,
306 make_tuple(make_merge_transform(make_tuple(NumGroupsToMerge, K)),
307 make_merge_transform(make_tuple(Z * Y * X, NumGroupsToMerge, C))),
310 }
311
312 template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
314 const index_t N,
315 const index_t K,
316 const index_t C,
317 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
318 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
319 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
320 const std::array<index_t, NDimSpatial + 3>& input_strides,
321 const std::array<index_t, NDimSpatial + 3>& weights_strides,
322 const std::array<index_t, NDimSpatial + 3>& output_strides,
323 const std::array<index_t, NDimSpatial>& conv_filter_strides,
324 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
325 const std::array<index_t, NDimSpatial>& input_left_pads,
326 const std::array<index_t, NDimSpatial>& input_right_pads,
327 const index_t batch_k)
328 {
329 using namespace ck;
330
331 const index_t Wi = input_spatial_lengths[0];
332
333 const index_t Wo = output_spatial_lengths[0];
334
335 const index_t X = filter_spatial_lengths[0];
336
337 const index_t ConvStrideW = conv_filter_strides[0];
338
339 const index_t ConvDilationW = conv_filter_dilations[0];
340
341 const index_t InLeftPadW = input_left_pads[0];
342
343 const index_t InRightPadW = input_right_pads[0];
344
345 const index_t GemmKTotal = N * Wo;
346 const index_t GemmM = K * NumGroupsToMerge;
347 const index_t GemmN = C * X * NumGroupsToMerge;
348
349 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
350 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
351
352 const index_t GemmKBatch = batch_k;
353 const index_t GemmK0 =
354 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
355 K0PerBlock;
356 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
357
358 const auto out_grid_desc = make_out_grid_desc<NDim>(N, Wo, K, output_strides);
359 const auto in_grid_desc = make_in_grid_desc<NDim>(N, Wi, C, input_strides);
360 const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, X, C, weights_strides);
361
362 if constexpr(ConvBackwardWeightSpecialization ==
364 {
365 // A: output tensor
366 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
367 out_grid_desc,
369 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
370 make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
373
374 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
375 out_gemmkpad_gemmm_grid_desc,
376 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
377 make_right_pad_transform(GemmM, PadGemmM)),
380
381 // B: input tensor
382 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
383 in_grid_desc,
385 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
386 make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
389
390 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
391 in_gemmkpad_gemmn_grid_desc,
392 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
393 make_right_pad_transform(GemmN, PadGemmN)),
396
397 // Pad: right-pad the weight descriptor by PadGemmM and PadGemmN
398 const auto wei_gemmm_gemmn_pad_grid_desc =
399 transform_tensor_descriptor(wei_grid_desc,
400 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
401 make_right_pad_transform(GemmN, PadGemmN)),
404
405 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
406 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
407 wei_gemmm_gemmn_pad_grid_desc);
408 }
409 else
410 {
411 // A: output tensor
412 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
413 out_grid_desc,
415 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
416 make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
419
420 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
421 out_gemmkpad_gemmm_grid_desc,
422 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
423 make_right_pad_transform(GemmM, PadGemmM)),
426
427 // B: input tensor
428 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
429 in_grid_desc,
431 make_pad_transform(Wi, InLeftPadW, InRightPadW),
432 make_pass_through_transform(NumGroupsToMerge),
436
437 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
438 in_n_hip_wip_c_grid_desc,
441 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
442 make_pass_through_transform(NumGroupsToMerge),
446
447 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
448 in_n_y_ho_x_wo_c_grid_desc,
449 make_tuple(make_merge_transform(make_tuple(X, NumGroupsToMerge, C)),
453
454 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
455 in_gemmktotal_gemmn_grid_desc,
456 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
460
461 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
462 in_gemmkpad_gemmn_grid_desc,
463 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
464 make_right_pad_transform(GemmN, PadGemmN)),
467
468 // Pad: right-pad the weight descriptor by PadGemmM and PadGemmN
469 const auto wei_gemmm_gemmn_pad_grid_desc =
470 transform_tensor_descriptor(wei_grid_desc,
471 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
472 make_right_pad_transform(GemmN, PadGemmN)),
475
476 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
477 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
478 wei_gemmm_gemmn_pad_grid_desc);
479 }
480
481 } // function end
482
483 template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
485 const index_t N,
486 const index_t K,
487 const index_t C,
488 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
489 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
490 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
491 const std::array<index_t, NDimSpatial + 3>& input_strides,
492 const std::array<index_t, NDimSpatial + 3>& weights_strides,
493 const std::array<index_t, NDimSpatial + 3>& output_strides,
494 const std::array<index_t, NDimSpatial>& conv_filter_strides,
495 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
496 const std::array<index_t, NDimSpatial>& input_left_pads,
497 const std::array<index_t, NDimSpatial>& input_right_pads,
498 const index_t batch_k)
499 {
500 using namespace ck;
501
502 const index_t Hi = input_spatial_lengths[0];
503 const index_t Wi = input_spatial_lengths[1];
504
505 const index_t Ho = output_spatial_lengths[0];
506 const index_t Wo = output_spatial_lengths[1];
507
508 const index_t Y = filter_spatial_lengths[0];
509 const index_t X = filter_spatial_lengths[1];
510
511 const index_t ConvStrideH = conv_filter_strides[0];
512 const index_t ConvStrideW = conv_filter_strides[1];
513
514 const index_t ConvDilationH = conv_filter_dilations[0];
515 const index_t ConvDilationW = conv_filter_dilations[1];
516
517 const index_t InLeftPadH = input_left_pads[0];
518 const index_t InLeftPadW = input_left_pads[1];
519
520 const index_t InRightPadH = input_right_pads[0];
521 const index_t InRightPadW = input_right_pads[1];
522
523 const index_t GemmKTotal = N * Ho * Wo;
524 const index_t GemmM = K * NumGroupsToMerge;
525 const index_t GemmN = C * X * Y * NumGroupsToMerge;
526
527 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
528 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
529
530 const index_t GemmKBatch = batch_k;
531 const index_t GemmK0 =
532 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
533 K0PerBlock;
534 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
535
536 const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
537 const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
538 const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
539
540 if constexpr(ConvBackwardWeightSpecialization ==
542 {
543 // A: output tensor
544 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
545 out_grid_desc,
547 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
548 make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
551
552 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
553 out_gemmkpad_gemmm_grid_desc,
554 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
555 make_right_pad_transform(GemmM, PadGemmM)),
558
559 // B: input tensor
560 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
561 in_grid_desc,
563 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
564 make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
567
568 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
569 in_gemmkpad_gemmn_grid_desc,
570 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
571 make_right_pad_transform(GemmN, PadGemmN)),
574
575 // Pad: right-pad the weight descriptor by PadGemmM and PadGemmN
576 const auto wei_gemmm_gemmn_pad_grid_desc =
577 transform_tensor_descriptor(wei_grid_desc,
578 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
579 make_right_pad_transform(GemmN, PadGemmN)),
582
583 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
584 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
585 wei_gemmm_gemmn_pad_grid_desc);
586 }
587 else
588 {
589 // A: output tensor
590 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
591 out_grid_desc,
593 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
594 make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
597
598 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
599 out_gemmkpad_gemmm_grid_desc,
600 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
601 make_right_pad_transform(GemmM, PadGemmM)),
604
605 // B: input tensor
606 const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
607 in_grid_desc,
609 make_pad_transform(Hi, InLeftPadH, InRightPadH),
610 make_pad_transform(Wi, InLeftPadW, InRightPadW),
611 make_pass_through_transform(NumGroupsToMerge),
617
618 const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
619 in_n_hip_wip_c_grid_desc,
622 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
623 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
624 make_pass_through_transform(NumGroupsToMerge),
631 Sequence<5>{},
632 Sequence<6>{}));
633
634 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
635 in_n_y_ho_x_wo_c_grid_desc,
636 make_tuple(make_merge_transform(make_tuple(Y, X, NumGroupsToMerge, C)),
637 make_merge_transform(make_tuple(N, Ho, Wo))),
640
641 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
642 in_gemmktotal_gemmn_grid_desc,
643 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
647
648 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
649 in_gemmkpad_gemmn_grid_desc,
650 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
651 make_right_pad_transform(GemmN, PadGemmN)),
654
655 // Pad: right-pad the weight descriptor by PadGemmM and PadGemmN
656 const auto wei_gemmm_gemmn_pad_grid_desc =
657 transform_tensor_descriptor(wei_grid_desc,
658 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
659 make_right_pad_transform(GemmN, PadGemmN)),
662
663 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
664 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
665 wei_gemmm_gemmn_pad_grid_desc);
666 }
667 }
668
669 template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
671 const index_t N,
672 const index_t K,
673 const index_t C,
674 const std::array<index_t, NDimSpatial>& input_spatial_lengths,
675 const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
676 const std::array<index_t, NDimSpatial>& output_spatial_lengths,
677 const std::array<index_t, NDimSpatial + 3>& input_strides,
678 const std::array<index_t, NDimSpatial + 3>& weights_strides,
679 const std::array<index_t, NDimSpatial + 3>& output_strides,
680 const std::array<index_t, NDimSpatial>& conv_filter_strides,
681 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
682 const std::array<index_t, NDimSpatial>& input_left_pads,
683 const std::array<index_t, NDimSpatial>& input_right_pads,
684 const index_t batch_k)
685 {
686 using namespace ck;
687
688 const index_t Di = input_spatial_lengths[0];
689 const index_t Hi = input_spatial_lengths[1];
690 const index_t Wi = input_spatial_lengths[2];
691
692 const index_t Do = output_spatial_lengths[0];
693 const index_t Ho = output_spatial_lengths[1];
694 const index_t Wo = output_spatial_lengths[2];
695
696 const index_t Z = filter_spatial_lengths[0];
697 const index_t Y = filter_spatial_lengths[1];
698 const index_t X = filter_spatial_lengths[2];
699
700 const index_t ConvStrideD = conv_filter_strides[0];
701 const index_t ConvStrideH = conv_filter_strides[1];
702 const index_t ConvStrideW = conv_filter_strides[2];
703
704 const index_t ConvDilationD = conv_filter_dilations[0];
705 const index_t ConvDilationH = conv_filter_dilations[1];
706 const index_t ConvDilationW = conv_filter_dilations[2];
707
708 const index_t InLeftPadD = input_left_pads[0];
709 const index_t InLeftPadH = input_left_pads[1];
710 const index_t InLeftPadW = input_left_pads[2];
711
712 const index_t InRightPadD = input_right_pads[0];
713 const index_t InRightPadH = input_right_pads[1];
714 const index_t InRightPadW = input_right_pads[2];
715
716 const index_t GemmKTotal = N * Do * Ho * Wo;
717 const index_t GemmM = K * NumGroupsToMerge;
718 const index_t GemmN = C * Z * X * Y * NumGroupsToMerge;
719
720 const auto PadGemmM = GemmM % MPerBlock == 0 ? 0 : MPerBlock - GemmM % MPerBlock;
721 const auto PadGemmN = GemmN % NPerBlock == 0 ? 0 : NPerBlock - GemmN % NPerBlock;
722
723 const index_t GemmKBatch = batch_k;
724 const index_t GemmK0 =
725 math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
726 K0PerBlock;
727 const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
728
729 const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
730 const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
731 const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
732
733 if constexpr(ConvBackwardWeightSpecialization ==
735 {
736 // A: output tensor
737 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
738 out_grid_desc,
740 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
741 make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
744
745 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
746 out_gemmkpad_gemmm_grid_desc,
747 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
748 make_right_pad_transform(GemmM, PadGemmM)),
751
752 // B: input tensor
753 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
754 in_grid_desc,
756 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
757 make_merge_transform(make_tuple(NumGroupsToMerge, GemmN / NumGroupsToMerge))),
760
761 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
762 in_gemmkpad_gemmn_grid_desc,
763 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
764 make_right_pad_transform(GemmN, PadGemmN)),
767
768 // Pad: right-pad the weight descriptor by PadGemmM and PadGemmN
769 const auto wei_gemmm_gemmn_pad_grid_desc =
770 transform_tensor_descriptor(wei_grid_desc,
771 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
772 make_right_pad_transform(GemmN, PadGemmN)),
775
776 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
777 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
778 wei_gemmm_gemmn_pad_grid_desc);
779 }
780 else
781 {
782 // A: output tensor
783 const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
784 out_grid_desc,
786 make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
787 make_merge_transform(make_tuple(NumGroupsToMerge, GemmM / NumGroupsToMerge))),
790
791 const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
792 out_gemmkpad_gemmm_grid_desc,
793 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
794 make_right_pad_transform(GemmM, PadGemmM)),
797
798 // B: input tensor
799 const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
800 in_grid_desc,
802 make_pad_transform(Di, InLeftPadD, InRightPadD),
803 make_pad_transform(Hi, InLeftPadH, InRightPadH),
804 make_pad_transform(Wi, InLeftPadW, InRightPadW),
805 make_pass_through_transform(NumGroupsToMerge),
808 Sequence<1>{},
809 Sequence<2>{},
810 Sequence<3>{},
811 Sequence<4>{},
812 Sequence<5>{}),
814 Sequence<1>{},
815 Sequence<2>{},
816 Sequence<3>{},
817 Sequence<4>{},
818 Sequence<5>{}));
819
820 const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
821 in_n_dip_hip_wip_c_grid_desc,
824 make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
825 make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
826 make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
827 make_pass_through_transform(NumGroupsToMerge),
830 Sequence<1>{},
831 Sequence<2>{},
832 Sequence<3>{},
833 Sequence<4>{},
834 Sequence<5>{}),
839 Sequence<7>{},
840 Sequence<8>{}));
841
842 const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
843 in_n_z_do_y_ho_x_wo_c_grid_desc,
844 make_tuple(make_merge_transform(make_tuple(Z, Y, X, NumGroupsToMerge, C)),
845 make_merge_transform(make_tuple(N, Do, Ho, Wo))),
848
849 const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
850 in_gemmktotal_gemmn_grid_desc,
851 make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
855
856 const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
857 in_gemmkpad_gemmn_grid_desc,
858 make_tuple(make_unmerge_transform(make_tuple(GemmKBatch * GemmK0, GemmK1Number)),
859 make_right_pad_transform(GemmN, PadGemmN)),
862
863 // Pad: right-pad the weight descriptor by PadGemmM and PadGemmN
864 const auto wei_gemmm_gemmn_pad_grid_desc =
865 transform_tensor_descriptor(wei_grid_desc,
866 make_tuple(make_right_pad_transform(GemmM, PadGemmM),
867 make_right_pad_transform(GemmN, PadGemmN)),
870
871 return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
872 in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
873 wei_gemmm_gemmn_pad_grid_desc);
874 }
875 } // function end
876};
877
878} // namespace tensor_operation
879} // namespace ck
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
ConvolutionBackwardWeightSpecialization
Definition convolution_backward_weight_specialization.hpp:13
@ Filter1x1Stride1Pad0
Definition convolution_backward_weight_specialization.hpp:15
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__host__ __device__ constexpr auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:37
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
__host__ __device__ constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:191
Definition utility/sequence.hpp:43
Transform conv bwd weight to gemm v2.
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:33
static constexpr auto make_wei_grid_desc(const index_t K, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:78
static constexpr auto make_in_grid_desc(const index_t N, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:142
static constexpr auto I1
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:35
static constexpr auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:219
static constexpr auto make_wei_grid_desc(const index_t K, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:169
static constexpr auto I0
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:34
static constexpr auto make_out_grid_desc(const index_t N, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:127
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &input_strides, const std::array< index_t, NDimSpatial+3 > &weights_strides, const std::array< index_t, NDimSpatial+3 > &output_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:313
static constexpr auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:264
static constexpr auto make_out_grid_desc(const index_t N, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:39
static constexpr auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:235
static constexpr auto make_in_grid_desc(const index_t N, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition transform_conv_bwd_weight_to_gemm_v2.hpp:53