transform_conv_fwd_to_gemm.hpp Source File

transform_conv_fwd_to_gemm.hpp Source File#

Composable Kernel: transform_conv_fwd_to_gemm.hpp Source File
tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp
Go to the documentation of this file.
1
2// SPDX-License-Identifier: MIT
3// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
4
5#pragma once
6#include "ck_tile/core.hpp"
8namespace ck_tile {
9
// ═══════════════════════════════════════════════════════════════════════
// Split-Image Information
// ═══════════════════════════════════════════════════════════════════════
// NOTE: The namespace-scope SplitImageInfo structure that used to be declared here was removed
// together with the recursive split code, its only user. The split-image decision is now made by
// TransformConvFwdToGemm::GetSplitImageInfo(), and the split itself is carried out in
// grouped_convolution_forward_invoker.hpp.
16
17template <index_t NDimSpatial,
18 ConvolutionSpecialization ConvSpecialization,
19 index_t VectorSizeA,
20 index_t VectorSizeB,
21 index_t VectorSizeC,
22 index_t NumGroupsToMerge = 1,
23 bool SplitN = false,
24 typename ADataType = float,
25 typename CDataType = float,
26 typename IndexType = index_t>
28{
29 private:
30 static constexpr auto I0 = number<0>{};
31 static constexpr auto I1 = number<1>{};
32 static constexpr auto I2 = number<2>{};
33 static constexpr auto I3 = number<3>{};
34 static constexpr auto I4 = number<4>{};
35 static constexpr auto I5 = number<5>{};
36
37 // Unified memory limit constant for both Split-N and Split-Image
38 static constexpr long_index_t TwoGB = (long_index_t{1} << 31); // 2GB
39
40 template <typename ConvDimsType>
41 static long_index_t calculate_element_space_size_impl(const ConvDimsType& lengths,
42 const ConvDimsType& strides,
43 index_t i)
44 {
45 long_index_t acc = 1;
46 for(; i < (NDimSpatial + 3); i++)
47 {
48 acc +=
49 static_cast<long_index_t>(lengths[i] - I1) * static_cast<long_index_t>(strides[i]);
50 }
51
52 return acc;
53 }
54
55 template <typename ConvDimsType>
56 static IndexType GetSplitedNSize(const ConvDimsType& a_g_n_c_wis_lengths,
57 const ConvDimsType& c_g_n_k_wos_lengths)
58 {
59
60 // Calculate strides internally assuming contiguous memory layout
61 ConvDimsType a_g_n_c_wis_strides, c_g_n_k_wos_strides;
62 const index_t num_dims = a_g_n_c_wis_lengths.size();
63
64 // Calculate strides for input tensor (innermost to outermost)
65 a_g_n_c_wis_strides[num_dims - 1] = 1;
66 for(index_t i = num_dims - 2; i >= 0; i--)
67 {
68 a_g_n_c_wis_strides[i] = a_g_n_c_wis_strides[i + 1] * a_g_n_c_wis_lengths[i + 1];
69 }
70
71 // Calculate strides for output tensor
72 c_g_n_k_wos_strides[num_dims - 1] = 1;
73 for(index_t i = num_dims - 2; i >= 0; i--)
74 {
75 c_g_n_k_wos_strides[i] = c_g_n_k_wos_strides[i + 1] * c_g_n_k_wos_lengths[i + 1];
76 }
77
78 const long_index_t a_element_space_size =
79 calculate_element_space_size_impl(a_g_n_c_wis_lengths, a_g_n_c_wis_strides, I1);
80 const long_index_t c_element_space_size =
81 calculate_element_space_size_impl(c_g_n_k_wos_lengths, c_g_n_k_wos_strides, I1);
82 const long_index_t element_space_size = ck_tile::max(
83 a_element_space_size * sizeof(ADataType), c_element_space_size * sizeof(CDataType));
84
85 const IndexType N = a_g_n_c_wis_lengths[I1];
86
87 if(element_space_size > TwoGB)
88 {
89 // Minimum divisor of N to not exceed 2GB
90 const auto divisor = ck_tile::integer_divide_ceil(element_space_size, TwoGB);
91
92 if(divisor <= static_cast<double>(N))
93 {
94 // Find least divisor of N larger than element_space_size / TwoGB
95 // Iterate up to sqrt(N). There are no divisors above this value.
96 for(IndexType least_divisor = divisor; least_divisor * least_divisor <= N;
97 least_divisor++)
98 {
99 if(N % least_divisor == 0)
100 {
101 IndexType result = N / least_divisor;
102 return result;
103 }
104 }
105 // Not found, process one Convolution N per block
106 return 1;
107 }
108 else
109 {
110 // Split Convolution's N dimension into N workgroups. However
111 // this still might not result in sufficiently small tensor,
112 // but at least later on we could divide the image as well.
113 return 1;
114 }
115 }
116 else
117 {
118 // Split N is not needed.
119 return N;
120 }
121 }
122
123 public:
124 // Structure to hold split-image decision and factors
132
133 // Calculate split-image factors AFTER considering split-N
134 // Returns: should_split flag and optimal split factors for D, H, W dimensions
135 // Strategy: Hierarchical splitting with priority order D → H → W
136 // Dynamically increases split factors until memory fits below threshold
137 //
138 // NOTE: Layout validation should be done at the invoker level before calling this function
139 // Split-image only works with specific layouts:
140 // 1D: NWGC (input), GKXC (weight), NWGK (output)
141 // 2D: NHWGC (input), GKYXC (weight), NHWGK (output)
142 // 3D: NDHWGC (input), GKZYXC (weight), NDHWGK (output)
143 CK_TILE_HOST static SplitImageInfo GetSplitImageInfo(
144 index_t G, index_t N, index_t C, index_t K, index_t D_out, index_t H_out, index_t W_out)
145 {
146 SplitImageInfo info{false, 1, 1, 1};
147
148 // Estimate memory (simplified calculation)
149 // Use max of input and output tensor sizes
150 // Cast to long_index_t to prevent overflow during multiplication
151 const long_index_t input_elements =
152 static_cast<long_index_t>(N) * D_out * H_out * W_out * C * G;
153 const long_index_t output_elements =
154 static_cast<long_index_t>(N) * D_out * H_out * W_out * K * G;
155 const long_index_t input_bytes = input_elements * sizeof(ADataType);
156 const long_index_t output_bytes = output_elements * sizeof(CDataType);
157 const long_index_t max_tensor_bytes =
158 (input_bytes > output_bytes) ? input_bytes : output_bytes;
159
160 // Calculate effective N after split-N (simplified - assume worst case N=1)
161 index_t effective_N = 1;
162 if(max_tensor_bytes > TwoGB && N > 1)
163 {
164 // Split-N will reduce to approximately N=1 per launch
165 effective_N = 1;
166 }
167 else
168 {
169 effective_N = N;
170 }
171
172 // Check if split-image is needed
173 auto calc_memory = [&](index_t d_split, index_t h_split, index_t w_split) -> long_index_t {
174 index_t d_piece = D_out / d_split;
175 index_t h_piece = H_out / h_split;
176 index_t w_piece = W_out / w_split;
177 // Cast to long_index_t to prevent overflow
178 return static_cast<long_index_t>(effective_N) * d_piece * h_piece * w_piece * K * G *
179 sizeof(CDataType);
180 };
181
182 // Calculate memory after split-N with no spatial split
183 const long_index_t memory_after_split_n = calc_memory(1, 1, 1);
184
185 // Check if split-image is needed
186 if(memory_after_split_n <= TwoGB)
187 {
188 info.should_split = false;
189 return info;
190 }
191
192 // Split-image is needed - use hierarchical priority: D → H → W
193 info.should_split = true;
194
195 // Hierarchical splitting strategy:
196 // 1D: Split W until below threshold
197 // 2D: Split H first, if still too large then split W
198 // 3D: Split D first, then H, then W
199
200 // IMPORTANT: Maximum 64 pieces total (hardcoded array limit in invoker)
201 constexpr index_t MAX_TOTAL_PIECES = 64;
202
203 // Start with no split
204 info.num_d_pieces = 1;
205 info.num_h_pieces = 1;
206 info.num_w_pieces = 1;
207
208 // Try splitting D first (for 3D)
209 if(D_out > 1)
210 {
211 index_t max_d_split = (D_out < MAX_TOTAL_PIECES) ? D_out : MAX_TOTAL_PIECES;
212 for(index_t d_split = 2; d_split <= max_d_split; d_split++)
213 {
214 info.num_d_pieces = d_split;
215 if(calc_memory(d_split, 1, 1) <= TwoGB)
216 {
217 return info; // D split alone is sufficient
218 }
219 }
220 // D split maxed out, try H next
221 }
222
223 // Try splitting H (for 2D/3D)
224 if(H_out > 1)
225 {
226 index_t max_h_split = MAX_TOTAL_PIECES / info.num_d_pieces;
227 max_h_split = (H_out < max_h_split) ? H_out : max_h_split;
228
229 for(index_t h_split = 2; h_split <= max_h_split; h_split++)
230 {
231 info.num_h_pieces = h_split;
232 if(calc_memory(info.num_d_pieces, h_split, 1) <= TwoGB)
233 {
234 return info; // D+H split is sufficient
235 }
236 }
237 // H split maxed out, try W next
238 }
239
240 // Try splitting W (for 1D/2D/3D)
241 index_t max_w_split = MAX_TOTAL_PIECES / (info.num_d_pieces * info.num_h_pieces);
242 max_w_split = (W_out < max_w_split) ? W_out : max_w_split;
243
244 for(index_t w_split = 2; w_split <= max_w_split; w_split++)
245 {
246 info.num_w_pieces = w_split;
247 if(calc_memory(info.num_d_pieces, info.num_h_pieces, w_split) <= TwoGB)
248 {
249 return info; // D+H+W split is sufficient
250 }
251 }
252
253 // If we reach here, even maximum split doesn't fit
254 // Use maximum allowed split as best effort (capped at 64 total pieces)
255 info.num_d_pieces = (D_out < 4) ? D_out : 4; // Cap at 4
256 info.num_h_pieces = (H_out < 4) ? H_out : 4; // Cap at 4
257 info.num_w_pieces = (W_out < 4) ? W_out : 4; // Cap at 4 (max 4×4×4=64)
258
259 return info;
260 }
261
262 public:
263 // Public getter methods for Split-N support
264 CK_TILE_HOST constexpr IndexType GetN() const { return N_; }
265 CK_TILE_HOST constexpr IndexType GetOriginalN() const { return original_N_; }
266
268
// Converting constructor: builds this transform from another object exposing the same set of
// geometry members (G_, N_, original_N_, spatial sizes, filter sizes, strides, dilations,
// pads, ZYX_). Each field is copied and static_cast to IndexType. Presumably the source is
// another TransformConvFwdToGemm instantiation with a different IndexType - verify at call sites.
// NOTE(review): the casts are unchecked; narrowing a wider source IndexType can silently
// change values that do not fit.
269 template <typename TransformConvFwdToGemmBase>
271 TransformConvFwdToGemm(const TransformConvFwdToGemmBase& transform_conv_fwd_to_gemm_base)
272 : G_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.G_)},
273 N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.N_)},
274 original_N_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.original_N_)},
275 Di_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Di_)},
276 Hi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Hi_)},
277 Wi_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wi_)},
278 Do_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Do_)},
279 Ho_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Ho_)},
280 Wo_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Wo_)},
281 Z_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Z_)},
282 Y_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.Y_)},
283 X_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.X_)},
284 K_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.K_)},
285 C_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.C_)},
286 ConvStrideD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideD_)},
287 ConvStrideH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideH_)},
288 ConvStrideW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvStrideW_)},
289 ConvDilationD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationD_)},
290 ConvDilationH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationH_)},
291 ConvDilationW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ConvDilationW_)},
292 InLeftPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadD_)},
293 InLeftPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadH_)},
294 InLeftPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InLeftPadW_)},
295 InRightPadD_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadD_)},
296 InRightPadH_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadH_)},
297 InRightPadW_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.InRightPadW_)},
298 ZYX_{static_cast<IndexType>(transform_conv_fwd_to_gemm_base.ZYX_)}
// All state is copied in the initializer list above; the body is intentionally empty.
299 {
300 }
301
302 template <typename ConvDimsType,
303 typename ConvSpatialDimsType,
304 index_t NDim = NDimSpatial,
305 typename std::enable_if<NDim == 1, bool>::type = false>
306 CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
307 const ConvDimsType& b_g_k_c_xs_lengths,
308 const ConvDimsType& c_g_n_k_wos_lengths,
309 const ConvSpatialDimsType& conv_filter_strides,
310 const ConvSpatialDimsType& conv_filter_dilations,
311 const ConvSpatialDimsType& input_left_pads,
312 const ConvSpatialDimsType& input_right_pads)
313 : G_{a_g_n_c_wis_lengths[I0]},
314 Di_{I1},
315 Hi_{I1},
316 Wi_{a_g_n_c_wis_lengths[I3]},
317 Do_{I1},
318 Ho_{I1},
319 Wo_{c_g_n_k_wos_lengths[I3]},
320 Z_{I1},
321 Y_{I1},
322 X_{b_g_k_c_xs_lengths[I3]},
323 K_{c_g_n_k_wos_lengths[I2]},
324 C_{b_g_k_c_xs_lengths[I2]},
325 ConvStrideD_{I1},
326 ConvStrideH_{I1},
327 ConvStrideW_{conv_filter_strides[I0]},
328 ConvDilationD_{I1},
329 ConvDilationH_{I1},
330 ConvDilationW_{conv_filter_dilations[I0]},
331 InLeftPadD_{I0},
332 InLeftPadH_{I0},
333 InLeftPadW_{input_left_pads[I0]},
334 InRightPadD_{I0},
335 InRightPadH_{I0},
336 InRightPadW_{input_right_pads[I0]},
337 ZYX_{X_}
338 {
339 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
340 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
341 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
342 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
343
344 // Store original N and initialize N_
345 original_N_ = N_ = c_g_n_k_wos_lengths[I1];
346
347 if constexpr(SplitN)
348 {
349 N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
350 }
351 }
352
353 template <typename ConvDimsType,
354 typename ConvSpatialDimsType,
355 index_t NDim = NDimSpatial,
356 typename std::enable_if<NDim == 2, bool>::type = false>
357 CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
358 const ConvDimsType& b_g_k_c_xs_lengths,
359 const ConvDimsType& c_g_n_k_wos_lengths,
360 const ConvSpatialDimsType& conv_filter_strides,
361 const ConvSpatialDimsType& conv_filter_dilations,
362 const ConvSpatialDimsType& input_left_pads,
363 const ConvSpatialDimsType& input_right_pads)
364 : G_{a_g_n_c_wis_lengths[I0]},
365 Di_{I1},
366 Hi_{a_g_n_c_wis_lengths[I3]},
367 Wi_{a_g_n_c_wis_lengths[I4]},
368 Do_{I1},
369 Ho_{c_g_n_k_wos_lengths[I3]},
370 Wo_{c_g_n_k_wos_lengths[I4]},
371 Z_{I1},
372 Y_{b_g_k_c_xs_lengths[I3]},
373 X_{b_g_k_c_xs_lengths[I4]},
374 K_{c_g_n_k_wos_lengths[I2]},
375 C_{b_g_k_c_xs_lengths[I2]},
376 ConvStrideD_{I1},
377 ConvStrideH_{conv_filter_strides[I0]},
378 ConvStrideW_{conv_filter_strides[I1]},
379 ConvDilationD_{I1},
380 ConvDilationH_{conv_filter_dilations[I0]},
381 ConvDilationW_{conv_filter_dilations[I1]},
382 InLeftPadD_{I0},
383 InLeftPadH_{input_left_pads[I0]},
384 InLeftPadW_{input_left_pads[I1]},
385 InRightPadD_{I0},
386 InRightPadH_{input_right_pads[I0]},
387 InRightPadW_{input_right_pads[I1]},
388 ZYX_{Y_ * X_}
389 {
390 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
391 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
392 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
393 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
394
395 // Store original N and initialize N_
396 original_N_ = N_ = c_g_n_k_wos_lengths[I1];
397
398 if constexpr(SplitN)
399 {
400 N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
401 }
402 }
403
404 template <typename ConvDimsType,
405 typename ConvSpatialDimsType,
406 index_t NDim = NDimSpatial,
407 typename std::enable_if<NDim == 3, bool>::type = false>
408 CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType& a_g_n_c_wis_lengths,
409 const ConvDimsType& b_g_k_c_xs_lengths,
410 const ConvDimsType& c_g_n_k_wos_lengths,
411 const ConvSpatialDimsType& conv_filter_strides,
412 const ConvSpatialDimsType& conv_filter_dilations,
413 const ConvSpatialDimsType& input_left_pads,
414 const ConvSpatialDimsType& input_right_pads)
415 : G_{a_g_n_c_wis_lengths[I0]},
416 Di_{a_g_n_c_wis_lengths[I3]},
417 Hi_{a_g_n_c_wis_lengths[I4]},
418 Wi_{a_g_n_c_wis_lengths[I5]},
419 Do_{c_g_n_k_wos_lengths[I3]},
420 Ho_{c_g_n_k_wos_lengths[I4]},
421 Wo_{c_g_n_k_wos_lengths[I5]},
422 Z_{b_g_k_c_xs_lengths[I3]},
423 Y_{b_g_k_c_xs_lengths[I4]},
424 X_{b_g_k_c_xs_lengths[I5]},
425 K_{c_g_n_k_wos_lengths[I2]},
426 C_{b_g_k_c_xs_lengths[I2]},
427 ConvStrideD_{conv_filter_strides[I0]},
428 ConvStrideH_{conv_filter_strides[I1]},
429 ConvStrideW_{conv_filter_strides[I2]},
430 ConvDilationD_{conv_filter_dilations[I0]},
431 ConvDilationH_{conv_filter_dilations[I1]},
432 ConvDilationW_{conv_filter_dilations[I2]},
433 InLeftPadD_{input_left_pads[I0]},
434 InLeftPadH_{input_left_pads[I1]},
435 InLeftPadW_{input_left_pads[I2]},
436 InRightPadD_{input_right_pads[I0]},
437 InRightPadH_{input_right_pads[I1]},
438 InRightPadW_{input_right_pads[I2]},
439 ZYX_{Z_ * Y_ * X_}
440 {
441 static_assert(std::is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
442 std::is_same_v<ConvSpatialDimsType, ck_tile::array<IndexType, NDimSpatial>>);
443 static_assert(std::is_same_v<ConvDimsType, std::array<IndexType, NDimSpatial + I3>> ||
444 std::is_same_v<ConvDimsType, ck_tile::array<IndexType, NDimSpatial + I3>>);
445
446 // Store original N and initialize N_
447 original_N_ = N_ = c_g_n_k_wos_lengths[I1];
448
449 if constexpr(SplitN)
450 {
451 N_ = GetSplitedNSize(a_g_n_c_wis_lengths, c_g_n_k_wos_lengths);
452 }
453 }
454
455 // Check if descriptors fit within memory threshold
// Returns true when both the input element count (N_ * Di_ * Hi_ * Wi_ * C_) and the output
// element count (N_ * Do_ * Ho_ * Wo_ * K_) are strictly below TwoGB / sizeof(ADataType).
456 // NOTE: Not currently used - split-image uses different approach in invoker
458 {
459 const long_index_t input_size = static_cast<long_index_t>(N_) * Di_ * Hi_ * Wi_ * C_;
460 const long_index_t output_size = static_cast<long_index_t>(N_) * Do_ * Ho_ * Wo_ * K_;
461
// NOTE(review): the ADataType-based threshold is also applied to output_size, whose elements
// are CDataType; if sizeof(CDataType) > sizeof(ADataType) the output check is too lenient.
// The counts also omit G, while the A-tensor strides used by the descriptor builders include
// it (WiStride_ = G_ * C_) - confirm whether a per-group check is intended.
462 const long_index_t threshold = TwoGB / sizeof(ADataType);
463 return (input_size < threshold) && (output_size < threshold);
464 }
465
466 // TODO: implement ck_tile::tensor_layout::convolution that describe packed/strided dimemsion as
467 // properties
468 template <typename ALayout,
469 typename std::enable_if<NDimSpatial == 1 &&
470 std::is_same_v<ALayout, tensor_layout::convolution::NWGC>,
471 bool>::type = false>
473 {
474 IndexType WiStride_ = G_ * C_;
475 IndexType CStrideTensorA_ = 1;
476 IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
477 IndexType GStrideTensorA_ = C_;
478
479 if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
480 {
481 if constexpr(NumGroupsToMerge == 1)
482 {
483 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
484 make_tuple(N_, Wo_, C_),
485 make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
487 I1);
489 in_gemmm_gemmk_desc,
494 }
495 else
496 {
497 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
498 make_tuple(N_, Wo_, NumGroupsToMerge, C_),
499 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
501 I1);
502
504 in_gemmm_groups_gemmk_desc,
505 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
509 }
510 }
511 else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
512 {
513 if constexpr(NumGroupsToMerge == 1)
514 {
515
516 const auto in_n_wi_c_desc =
518 make_tuple(NStrideTensorA_, WiStride_),
520 I1);
521
522 const auto in_n_wip_c_desc = transform_tensor_descriptor(
523 in_n_wi_c_desc,
525 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
528
529 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
530 in_n_wip_c_desc,
533 make_tuple(ConvDilationW_, ConvStrideW_))),
536
538 in_n_x_wo_c_desc,
543 }
544 else
545 {
546 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
547 make_tuple(N_, Wi_, NumGroupsToMerge),
548 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_),
550 I1);
551
552 const auto in_n_wip_c_desc = transform_tensor_descriptor(
553 in_n_wi_c_desc,
555 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
556 make_pass_through_transform(NumGroupsToMerge)),
559
560 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
561 in_n_wip_c_desc,
564 make_tuple(ConvDilationW_, ConvStrideW_)),
565 make_pass_through_transform(NumGroupsToMerge)),
568
570 in_n_x_wo_c_desc,
571 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
575 }
576 }
577 else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
578 {
579 if constexpr(NumGroupsToMerge == 1)
580 {
581 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
582 make_tuple(N_, Wi_, C_),
583 make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
585 I1);
586
587 const auto in_n_wo_c_desc = transform_tensor_descriptor(
588 in_n_wi_c_desc,
590 make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
594
596 in_n_wo_c_desc,
601 }
602 else
603 {
604 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
605 make_tuple(N_, Wi_, NumGroupsToMerge, C_),
606 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
608 I1);
609
610 const auto in_n_wo_c_desc = transform_tensor_descriptor(
611 in_n_wi_c_desc,
613 make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
614 make_pass_through_transform(NumGroupsToMerge),
618
620 in_n_wo_c_desc,
621 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
625 }
626 }
627 else
628 {
629 if constexpr(NumGroupsToMerge == 1)
630 {
631 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
632 make_tuple(N_, Wi_, C_),
633 make_tuple(NStrideTensorA_, WiStride_, CStrideTensorA_),
635 I1);
636
637 const auto in_n_wip_c_desc = transform_tensor_descriptor(
638 in_n_wi_c_desc,
640 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
644
645 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
646 in_n_wip_c_desc,
649 make_tuple(ConvDilationW_, ConvStrideW_)),
653
655 in_n_x_wo_c_desc,
660 }
661 else
662 {
663 const auto in_n_wi_c_desc = make_naive_tensor_descriptor(
664 make_tuple(N_, Wi_, NumGroupsToMerge, C_),
665 make_tuple(NStrideTensorA_, WiStride_, GStrideTensorA_, CStrideTensorA_),
667 I1);
668
669 const auto in_n_wip_c_desc = transform_tensor_descriptor(
670 in_n_wi_c_desc,
672 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
673 make_pass_through_transform(NumGroupsToMerge),
677
678 const auto in_n_x_wo_c_desc = transform_tensor_descriptor(
679 in_n_wip_c_desc,
682 make_tuple(ConvDilationW_, ConvStrideW_)),
683 make_pass_through_transform(NumGroupsToMerge),
687
689 in_n_x_wo_c_desc,
690 make_tuple(make_merge_transform(make_tuple(N_, Wo_, NumGroupsToMerge)),
694 }
695 }
696 }
697
698 template <typename ALayout,
699 typename std::enable_if<
700 NDimSpatial == 2 && std::is_same_v<ALayout, tensor_layout::convolution::NHWGC>,
701 bool>::type = false>
703
704 {
705 IndexType HiStride_ = Wi_ * G_ * C_;
706 IndexType WiStride_ = G_ * C_;
707 IndexType CStrideTensorA_ = 1;
708 IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
709 IndexType GStrideTensorA_ = C_;
710
711 if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
712 {
713 if constexpr(NumGroupsToMerge == 1)
714 {
715 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
716 make_tuple(N_, Ho_, Wo_, C_),
717 make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
719 I1);
720
722 in_gemmm_gemmk_desc,
727 }
728 else
729 {
730 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
731 make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, C_),
733 NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
735 I1);
736
738 in_gemmm_groups_gemmk_desc,
739 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
743 }
744 }
745 else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
746 {
747 if constexpr(NumGroupsToMerge == 1)
748 {
749 const auto in_n_hi_wi_c_desc =
751 make_tuple(NStrideTensorA_, HiStride_, WiStride_),
753 I1);
754
755 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
756 in_n_hi_wi_c_desc,
758 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
759 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
762
763 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
764 in_n_hip_wip_c_desc,
767 make_tuple(ConvDilationH_, ConvStrideH_)),
769 make_tuple(ConvDilationW_, ConvStrideW_))),
772
774 in_n_y_ho_x_wo_c_desc,
779 }
780 else
781 {
782 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
783 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge),
784 make_tuple(NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_),
786 I1);
787
788 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
789 in_n_hi_wi_groups_c_desc,
791 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
792 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
793 make_pass_through_transform(NumGroupsToMerge)),
796
797 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
798 in_n_hip_wip_groups_c_desc,
801 make_tuple(ConvDilationH_, ConvStrideH_)),
803 make_tuple(ConvDilationW_, ConvStrideW_)),
804 make_pass_through_transform(NumGroupsToMerge)),
807
809 in_n_y_ho_x_wo_groups_c_desc,
810 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
814 }
815 }
816 else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
817 {
818 if constexpr(NumGroupsToMerge == 1)
819 {
820 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
821 make_tuple(N_, Hi_, Wi_, C_),
822 make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
824 I1);
825
826 const auto in_n_ho_wo_c_desc = transform_tensor_descriptor(
827 in_n_hi_wi_c_desc,
829 make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
830 make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
834
836 in_n_ho_wo_c_desc,
841 }
842 else
843 {
844 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
845 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
847 NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
849 I1);
850
851 const auto in_n_ho_wo_groups_c_desc = transform_tensor_descriptor(
852 in_n_hi_wi_groups_c_desc,
854 make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
855 make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
856 make_pass_through_transform(NumGroupsToMerge),
862
864 in_n_ho_wo_groups_c_desc,
865 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
869 }
870 }
871 else
872 {
873 if constexpr(NumGroupsToMerge == 1)
874 {
875 const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
876 make_tuple(N_, Hi_, Wi_, C_),
877 make_tuple(NStrideTensorA_, HiStride_, WiStride_, CStrideTensorA_),
879 I1);
880
881 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
882 in_n_hi_wi_c_desc,
884 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
885 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
889
890 const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
891 in_n_hip_wip_c_desc,
894 make_tuple(ConvDilationH_, ConvStrideH_)),
896 make_tuple(ConvDilationW_, ConvStrideW_)),
900
902 in_n_y_ho_x_wo_c_desc,
904 make_merge_transform(make_tuple(Y_, X_, C_))),
907 }
908 else
909 {
910
911 const auto in_n_hi_wi_groups_c_desc = make_naive_tensor_descriptor(
912 make_tuple(N_, Hi_, Wi_, NumGroupsToMerge, C_),
914 NStrideTensorA_, HiStride_, WiStride_, GStrideTensorA_, CStrideTensorA_),
916 I1);
917
918 const auto in_n_hip_wip_groups_c_desc = transform_tensor_descriptor(
919 in_n_hi_wi_groups_c_desc,
921 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
922 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
923 make_pass_through_transform(NumGroupsToMerge),
929
930 const auto in_n_y_ho_x_wo_groups_c_desc = transform_tensor_descriptor(
931 in_n_hip_wip_groups_c_desc,
934 make_tuple(ConvDilationH_, ConvStrideH_)),
936 make_tuple(ConvDilationW_, ConvStrideW_)),
937 make_pass_through_transform(NumGroupsToMerge),
944 sequence<5>{},
945 sequence<6>{}));
946
948 in_n_y_ho_x_wo_groups_c_desc,
949 make_tuple(make_merge_transform(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge)),
950 make_merge_transform(make_tuple(Y_, X_, C_))),
953 }
954 }
955 }
956
957 template <typename ALayout,
958 typename std::enable_if<
959 NDimSpatial == 3 && std::is_same_v<ALayout, tensor_layout::convolution::NDHWGC>,
960 bool>::type = false>
962
963 {
964 IndexType DiStride_ = Hi_ * Wi_ * G_ * C_;
965 IndexType HiStride_ = Wi_ * G_ * C_;
966 IndexType WiStride_ = G_ * C_;
967 IndexType CStrideTensorA_ = 1;
968 IndexType NStrideTensorA_ = Di_ * Hi_ * Wi_ * G_ * C_;
969 IndexType GStrideTensorA_ = C_;
970
971 if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
972 {
973 if constexpr(NumGroupsToMerge == 1)
974 {
975 const auto in_gemmm_gemmk_desc = make_naive_tensor_descriptor(
976 make_tuple(N_, Do_, Ho_, Wo_, C_),
977 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
979 I1);
980
982 in_gemmm_gemmk_desc,
983 make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
987 }
988 else
989 {
990 const auto in_gemmm_groups_gemmk_desc = make_naive_tensor_descriptor(
991 make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, C_),
992 make_tuple(NStrideTensorA_,
993 DiStride_,
994 HiStride_,
995 WiStride_,
996 GStrideTensorA_,
997 CStrideTensorA_),
999 I1);
1000
1002 in_gemmm_groups_gemmk_desc,
1003 make_tuple(
1004 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1008 }
1009 }
1010 else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
1011 {
1012 if constexpr(NumGroupsToMerge == 1)
1013 {
1014 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1015 make_tuple(N_, Di_, Hi_, Wi_),
1016 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_),
1018 I1);
1019
1020 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1021 in_n_di_hi_wi_c_desc,
1023 make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
1024 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
1025 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_)),
1028
1029 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1030 in_n_hip_wip_c_desc,
1033 make_tuple(ConvDilationD_, ConvStrideD_)),
1035 make_tuple(ConvDilationH_, ConvStrideH_)),
1037 make_tuple(ConvDilationW_, ConvStrideW_))),
1039 make_tuple(
1041
1043 in_n_z_do_y_ho_x_wo_c_desc,
1044 make_tuple(
1045 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
1049 }
1050 else
1051 {
1052 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1053 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge),
1054 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, GStrideTensorA_),
1056 I1);
1057
1058 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1059 in_n_di_hi_wi_c_desc,
1061 make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
1062 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
1063 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
1064 make_pass_through_transform(NumGroupsToMerge)),
1065 make_tuple(
1067 make_tuple(
1069
1070 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1071 in_n_hip_wip_c_desc,
1074 make_tuple(ConvDilationD_, ConvStrideD_)),
1076 make_tuple(ConvDilationH_, ConvStrideH_)),
1078 make_tuple(ConvDilationW_, ConvStrideW_)),
1079 make_pass_through_transform(NumGroupsToMerge)),
1080 make_tuple(
1086 sequence<7>{}));
1087
1089 in_n_z_do_y_ho_x_wo_c_desc,
1090 make_tuple(
1091 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1095 }
1096 }
1097 else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
1098 {
1099 if constexpr(NumGroupsToMerge == 1)
1100 {
1101 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1102 make_tuple(N_, Di_, Hi_, Wi_, C_),
1103 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
1105 I1);
1106
1107 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
1108 in_n_di_hi_wi_c_desc,
1110 make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
1111 make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
1112 make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
1114 make_tuple(
1116 make_tuple(
1118
1120 in_n_do_ho_wo_c_desc,
1121 make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
1125 }
1126 else
1127 {
1128 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1129 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
1130 make_tuple(NStrideTensorA_,
1131 DiStride_,
1132 HiStride_,
1133 WiStride_,
1134 GStrideTensorA_,
1135 CStrideTensorA_),
1137 I1);
1138
1139 const auto in_n_do_ho_wo_c_desc = transform_tensor_descriptor(
1140 in_n_di_hi_wi_c_desc,
1142 make_embed_transform(make_tuple(Do_), make_tuple(ConvStrideD_)),
1143 make_embed_transform(make_tuple(Ho_), make_tuple(ConvStrideH_)),
1144 make_embed_transform(make_tuple(Wo_), make_tuple(ConvStrideW_)),
1145 make_pass_through_transform(NumGroupsToMerge),
1148 sequence<1>{},
1149 sequence<2>{},
1150 sequence<3>{},
1151 sequence<4>{},
1152 sequence<5>{}),
1154 sequence<1>{},
1155 sequence<2>{},
1156 sequence<3>{},
1157 sequence<4>{},
1158 sequence<5>{}));
1159
1161 in_n_do_ho_wo_c_desc,
1162 make_tuple(
1163 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1167 }
1168 }
1169 else
1170 {
1171 if constexpr(NumGroupsToMerge == 1)
1172 {
1173 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1174 make_tuple(N_, Di_, Hi_, Wi_, C_),
1175 make_tuple(NStrideTensorA_, DiStride_, HiStride_, WiStride_, CStrideTensorA_),
1177 I1);
1178
1179 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1180 in_n_di_hi_wi_c_desc,
1182 make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
1183 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
1184 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
1186 make_tuple(
1188 make_tuple(
1190
1191 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1192 in_n_hip_wip_c_desc,
1195 make_tuple(ConvDilationD_, ConvStrideD_)),
1197 make_tuple(ConvDilationH_, ConvStrideH_)),
1199 make_tuple(ConvDilationW_, ConvStrideW_)),
1201 make_tuple(
1207 sequence<7>{}));
1208
1210 in_n_z_do_y_ho_x_wo_c_desc,
1211 make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
1212 make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
1215 }
1216 else
1217 {
1218 const auto in_n_di_hi_wi_c_desc = make_naive_tensor_descriptor(
1219 make_tuple(N_, Di_, Hi_, Wi_, NumGroupsToMerge, C_),
1220 make_tuple(NStrideTensorA_,
1221 DiStride_,
1222 HiStride_,
1223 WiStride_,
1224 GStrideTensorA_,
1225 CStrideTensorA_),
1227 I1);
1228
1229 const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
1230 in_n_di_hi_wi_c_desc,
1232 make_pad_transform(Di_, InLeftPadD_, InRightPadD_),
1233 make_pad_transform(Hi_, InLeftPadH_, InRightPadH_),
1234 make_pad_transform(Wi_, InLeftPadW_, InRightPadW_),
1235 make_pass_through_transform(NumGroupsToMerge),
1238 sequence<1>{},
1239 sequence<2>{},
1240 sequence<3>{},
1241 sequence<4>{},
1242 sequence<5>{}),
1244 sequence<1>{},
1245 sequence<2>{},
1246 sequence<3>{},
1247 sequence<4>{},
1248 sequence<5>{}));
1249
1250 const auto in_n_z_do_y_ho_x_wo_c_desc = transform_tensor_descriptor(
1251 in_n_hip_wip_c_desc,
1254 make_tuple(ConvDilationD_, ConvStrideD_)),
1256 make_tuple(ConvDilationH_, ConvStrideH_)),
1258 make_tuple(ConvDilationW_, ConvStrideW_)),
1259 make_pass_through_transform(NumGroupsToMerge),
1262 sequence<1>{},
1263 sequence<2>{},
1264 sequence<3>{},
1265 sequence<4>{},
1266 sequence<5>{}),
1271 sequence<7>{},
1272 sequence<8>{}));
1273
1275 in_n_z_do_y_ho_x_wo_c_desc,
1276 make_tuple(
1277 make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge)),
1278 make_merge_transform(make_tuple(Z_, Y_, X_, C_))),
1281 }
1282 }
1283 }
1284
// MakeBDescriptor_N_K: builds the weight (B) tensor descriptor viewed as a GEMM
// N x K matrix. Only enabled for the GKXC / GKYXC / GKZYXC weight layouts.
// The strides below describe a packed (G, K, Z, Y, X, C) weight tensor:
// C is innermost (stride 1), one K row spans Z*Y*X*C elements, and one group
// spans K*Z*Y*X*C elements.
1285 template <
1286 typename BLayout,
1287 typename std::enable_if<std::is_same_v<BLayout, tensor_layout::convolution::GKXC> ||
1288 std::is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
1289 std::is_same_v<BLayout, tensor_layout::convolution::GKZYXC>,
1290 bool>::type = false>
1292 {
1293 IndexType CStrideTensorB_ = 1;
1294 IndexType KStrideTensorB_ = Z_ * Y_ * X_ * C_;
1295 IndexType GStrideTensorB_ = K_ * Z_ * Y_ * X_ * C_;
1296
1297 if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
1298 {
// Filter3x3: the GEMM-K extent is a compile-time constant equal to the
// number of filter taps (3, 9 or 27 for 1D, 2D or 3D convolutions).
1299 using FilterSizeNumType =
1300 std::conditional_t<NDimSpatial == 1,
1301 number<3>,
1302 std::conditional_t<NDimSpatial == 2, number<9>, number<27>>>;
1303
1304 if constexpr(NumGroupsToMerge == 1)
1305 {
// Single group: a plain (K_, FilterSize) matrix with row stride FilterSize.
// NOTE(review): the merged-groups branch below uses KStrideTensorB_ =
// Z_*Y_*X_*C_ as the K stride. The two agree only when C_ == 1, which
// presumably always holds for Filter3x3 (e.g. depthwise). Confirm.
1306 return make_naive_tensor_descriptor(make_tuple(K_, FilterSizeNumType{}),
1307 make_tuple(FilterSizeNumType{}, I1),
1309 I1);
1310 }
1311 else
1312 {
1313
// Multiple groups: view the weights as (K_, NumGroupsToMerge, FilterSize),
// then merge K_ and the group index into a single GEMM-N dimension while
// passing GEMM-K (the filter taps) through unchanged.
1314 const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
1315 make_tuple(K_, NumGroupsToMerge, FilterSizeNumType{}),
1316 make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
1318 I1);
1320 wei_gemmn_groups_gemmk_desc,
1321 make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
1322 make_pass_through_transform(FilterSizeNumType{})),
1325 }
1326 }
1327 else
1328 {
// General case: GEMM-K is ZYX_ * C_ (filter spatial extent times input
// channels), and GEMM-N is K_ (times NumGroupsToMerge when merging groups).
1329 if constexpr(NumGroupsToMerge == 1)
1330 {
1331 return make_naive_tensor_descriptor(make_tuple(K_, ZYX_ * C_),
1332 make_tuple(ZYX_ * C_, I1),
1334 I1);
1335 }
1336 else
1337 {
// Same K_/group merge as in the Filter3x3 branch, with a runtime GEMM-K.
1338 const auto wei_gemmn_groups_gemmk_desc = make_naive_tensor_descriptor(
1339 make_tuple(K_, NumGroupsToMerge, ZYX_ * C_),
1340 make_tuple(KStrideTensorB_, GStrideTensorB_, CStrideTensorB_),
1342 I1);
1344 wei_gemmn_groups_gemmk_desc,
1345 make_tuple(make_merge_transform(make_tuple(K_, NumGroupsToMerge)),
1346 make_pass_through_transform(ZYX_ * C_)),
1349 }
1350 }
1351 }
1352
// MakeCDescriptor_M_N (1D, NWGK output layout): builds the output (C) tensor
// descriptor viewed as a GEMM M x N matrix. M covers N_ * Wo_ output rows and
// N covers K_ output channels; when groups are merged, both are multiplied by
// NumGroupsToMerge.
1353 template <typename CLayout,
1354 index_t NDimSp = NDimSpatial,
1355 typename std::enable_if<NDimSp == 1 &&
1356 std::is_same_v<CLayout, tensor_layout::convolution::NWGK>,
1357 bool>::type = false>
1359 {
// Strides of the packed output tensor: K innermost, then G, then Wo.
1360 IndexType WoStride_ = G_ * K_;
1361 IndexType KStrideTensorC_ = 1;
1362 IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
1363 IndexType GStrideTensorC_ = K_;
1364
1365 const IndexType NDoHoWo = N_ * Wo_;
1366 if constexpr(NumGroupsToMerge == 1)
1367 {
// N_ and Wo_ are collapsed into one row index with stride WoStride_.
// NOTE(review): this relies on NStrideTensorC_ == Wo_ * WoStride_, i.e.
// Do_ * Ho_ == 1, which presumably holds for 1D convolutions. Confirm.
1368 return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
1369 make_tuple(WoStride_, KStrideTensorC_),
1371 I1);
1372 }
1373 else
1374 {
// Add a trailing unit dimension. Padding it (below) gives every merged
// group its own block; only index 0 of that dimension is real data.
1375 const auto nhwo_groups_k_1_desc = make_naive_tensor_descriptor(
1376 make_tuple(N_, Wo_, NumGroupsToMerge, K_, 1),
1377 make_tuple(
1378 NStrideTensorC_, WoStride_, GStrideTensorC_, KStrideTensorC_, GStrideTensorC_),
1380 I1);
1381 // Pad the trailing unit dimension (length 1) up to NumGroupsToMerge (right padding only).
1382 const auto padded_desc = transform_tensor_descriptor(
1383 nhwo_groups_k_1_desc,
1385 make_pass_through_transform(NumGroupsToMerge),
1387 make_pad_transform(1, 0, NumGroupsToMerge - 1)),
1390 // Only the diagonal (group, group) blocks are wanted. XOR of two equal indices is 0,
1391 // so off-diagonal blocks map to a nonzero index, which lies in the padding.
1392 // NumGroupsToMerge must be a power of two so the XOR result stays in range without a modulo.
1393 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1394 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1395 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1396 const auto unmerged_padded_desc = transform_tensor_descriptor(
1397 padded_desc,
1399 make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
1403 // Merge into the GEMM M and N dimensions: (NDoHoWo, group) -> M, (K_, group) -> N.
1405 unmerged_padded_desc,
1406 make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
1407 make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
1410 }
1411 }
1412
// MakeCDescriptor_M_N (2D, NHWGK output layout): builds the output (C) tensor
// descriptor viewed as a GEMM M x N matrix. M covers N_ * Ho_ * Wo_ output rows
// and N covers K_ output channels; when groups are merged, both are multiplied
// by NumGroupsToMerge.
1413 template <typename CLayout,
1414 index_t NDimSp = NDimSpatial,
1415
1416 typename std::enable_if<
1417 NDimSp == 2 && std::is_same_v<CLayout, tensor_layout::convolution::NHWGK>,
1418 bool>::type = false>
1420 {
// Strides of the packed output tensor: K innermost, then G, Wo, Ho.
1421 IndexType HoStride_ = Wo_ * G_ * K_;
1422 IndexType WoStride_ = G_ * K_;
1423 IndexType KStrideTensorC_ = 1;
1424 IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
1425 IndexType GStrideTensorC_ = K_;
1426
1427 const IndexType NDoHoWo = N_ * Ho_ * Wo_;
1428 if constexpr(NumGroupsToMerge == 1)
1429 {
// N_, Ho_ and Wo_ are collapsed into one row index with stride WoStride_.
// NOTE(review): this relies on NStrideTensorC_ == Ho_ * HoStride_, i.e.
// Do_ == 1, which presumably holds for 2D convolutions. Confirm.
1430 return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
1431 make_tuple(WoStride_, KStrideTensorC_),
1433 I1);
1434 }
1435 else
1436 {
// Add a trailing unit dimension. Padding it (below) gives every merged
// group its own block; only index 0 of that dimension is real data.
1437 const auto nhwo_groups_k_1_desc =
1438 make_naive_tensor_descriptor(make_tuple(N_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
1439 make_tuple(NStrideTensorC_,
1440 HoStride_,
1441 WoStride_,
1442 GStrideTensorC_,
1443 KStrideTensorC_,
1444 GStrideTensorC_),
1446 I1);
1447 // Pad the trailing unit dimension (length 1) up to NumGroupsToMerge (right padding only).
1448 const auto padded_desc = transform_tensor_descriptor(
1449 nhwo_groups_k_1_desc,
1451 make_pass_through_transform(NumGroupsToMerge),
1453 make_pad_transform(1, 0, NumGroupsToMerge - 1)),
1456 // Only the diagonal (group, group) blocks are wanted. XOR of two equal indices is 0,
1457 // so off-diagonal blocks map to a nonzero index, which lies in the padding.
1458 // NumGroupsToMerge must be a power of two so the XOR result stays in range without a modulo.
1459 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1460 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1461 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1462 const auto unmerged_padded_desc = transform_tensor_descriptor(
1463 padded_desc,
1465 make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
1469 // Merge into the GEMM M and N dimensions: (NDoHoWo, group) -> M, (K_, group) -> N.
1471 unmerged_padded_desc,
1472 make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
1473 make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
1476 }
1477 }
1478
// MakeCDescriptor_M_N (3D, NDHWGK output layout): builds the output (C) tensor
// descriptor viewed as a GEMM M x N matrix. M covers N_ * Do_ * Ho_ * Wo_ output
// rows and N covers K_ output channels; when groups are merged, both are
// multiplied by NumGroupsToMerge.
1479 template <typename CLayout,
1480 index_t NDimSp = NDimSpatial,
1481 typename std::enable_if<
1482 NDimSp == 3 && std::is_same_v<CLayout, tensor_layout::convolution::NDHWGK>,
1483 bool>::type = false>
1485 {
// Strides of the packed output tensor: K innermost, then G, Wo, Ho, Do.
1486 IndexType DoStride_ = Ho_ * Wo_ * G_ * K_;
1487 IndexType HoStride_ = Wo_ * G_ * K_;
1488 IndexType WoStride_ = G_ * K_;
1489 IndexType KStrideTensorC_ = 1;
1490 IndexType NStrideTensorC_ = Do_ * Ho_ * Wo_ * G_ * K_;
1491 IndexType GStrideTensorC_ = K_;
1492
1493 const IndexType NDoHoWo = N_ * Do_ * Ho_ * Wo_;
1494 if constexpr(NumGroupsToMerge == 1)
1495 {
// N_, Do_, Ho_ and Wo_ are contiguous in steps of WoStride_, so they
// collapse into one row index with that stride.
1496 return make_naive_tensor_descriptor(make_tuple(NDoHoWo, K_),
1497 make_tuple(WoStride_, KStrideTensorC_),
1499 I1);
1500 }
1501 else
1502 {
// Add a trailing unit dimension. Padding it (below) gives every merged
// group its own block; only index 0 of that dimension is real data.
1503 const auto nhwo_groups_k_1_desc =
1504 make_naive_tensor_descriptor(make_tuple(N_, Do_, Ho_, Wo_, NumGroupsToMerge, K_, 1),
1505 make_tuple(NStrideTensorC_,
1506 DoStride_,
1507 HoStride_,
1508 WoStride_,
1509 GStrideTensorC_,
1510 KStrideTensorC_,
1511 GStrideTensorC_),
1513 I1);
1514 // Pad the trailing unit dimension (length 1) up to NumGroupsToMerge (right padding only);
// N_, Do_, Ho_ and Wo_ are merged into a single row dimension at the same time.
1515 const auto padded_desc = transform_tensor_descriptor(
1516 nhwo_groups_k_1_desc,
1517 make_tuple(make_merge_transform(make_tuple(N_, Do_, Ho_, Wo_)),
1518 make_pass_through_transform(NumGroupsToMerge),
1520 make_pad_transform(1, 0, NumGroupsToMerge - 1)),
1523 // Only the diagonal (group, group) blocks are wanted. XOR of two equal indices is 0,
1524 // so off-diagonal blocks map to a nonzero index, which lies in the padding.
1525 // NumGroupsToMerge must be a power of two so the XOR result stays in range without a modulo.
1526 static_assert(NumGroupsToMerge == 1 || NumGroupsToMerge == 2 || NumGroupsToMerge == 4 ||
1527 NumGroupsToMerge == 8 || NumGroupsToMerge == 16 ||
1528 NumGroupsToMerge == 32 || NumGroupsToMerge == 64);
1529 const auto unmerged_padded_desc = transform_tensor_descriptor(
1530 padded_desc,
1532 make_xor_transform(make_tuple(NumGroupsToMerge, NumGroupsToMerge)),
1536 // Merge into the GEMM M and N dimensions: (NDoHoWo, group) -> M, (K_, group) -> N.
1538 unmerged_padded_desc,
1539 make_tuple(make_merge_transform(make_tuple(NDoHoWo, NumGroupsToMerge)),
1540 make_merge_transform(make_tuple(K_, NumGroupsToMerge))),
1543 }
1544 }
1545
1546 // ═══════════════════════════════════════════════════════════════════════
1547 // Split-Image Calculation (AFTER Split-N)
1548 // ═══════════════════════════════════════════════════════════════════════
1549 // The split-image calculation that used to live here computed its information from
1550 // N_ (the batch size after Split-N), which kept offsets correct when Split-N and
1551 // Split-Image were active at the same time. That method has since been removed (see the NOTE below).
1552
1553 // NOTE: Deleted CalculateSplitImage() and LaunchWithRecursiveSplit() - dead code
1554 // Current split-image implementation is in grouped_convolution_forward_invoker.hpp
1555
1556 public:
1557 private:
// G_: number of groups (part of the output strides G_ * K_).
// N_: batch size used to build the GEMM M dimension.
// original_N_: presumably the batch size before Split-N (exposed via GetOriginalN). Confirm.
1558 IndexType G_, N_, original_N_;
// Input spatial lengths; padded by InLeftPad*/InRightPad* when building the A descriptor.
1559 IndexType Di_, Hi_, Wi_;
// Output spatial lengths; folded into the GEMM M dimension.
1560 IndexType Do_, Ho_, Wo_;
// Filter spatial lengths; merged with C_ into the GEMM K dimension of the A descriptor.
1561 IndexType Z_, Y_, X_;
// Output channels (K_) and input channels (C_), both per group.
1562 IndexType K_, C_;
// Convolution strides and dilations, used as embed-transform coefficients.
1563 IndexType ConvStrideD_, ConvStrideH_, ConvStrideW_;
1564 IndexType ConvDilationD_, ConvDilationH_, ConvDilationW_;
// Input padding on the left/right side of each spatial dimension.
1565 IndexType InLeftPadD_, InLeftPadH_, InLeftPadW_;
1566 IndexType InRightPadD_, InRightPadH_, InRightPadW_;
// Filter spatial size used as ZYX_ * C_ for the B descriptor's GEMM K; presumably Z_ * Y_ * X_. Confirm.
1567 IndexType ZYX_;
1568};
1569
1570} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
ConvolutionSpecialization
Definition convolution_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_specialization.hpp:14
@ Filter3x3
Definition convolution_specialization.hpp:15
@ Filter1x1Pad0
Definition convolution_specialization.hpp:13
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition coordinate_transform.hpp:1565
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
int64_t long_index_t
Definition integer.hpp:11
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition coordinate_transform.hpp:1594
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:126
index_t num_h_pieces
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:129
index_t num_d_pieces
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:128
bool should_split
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:127
index_t num_w_pieces
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:130
CK_TILE_HOST constexpr IndexType GetN() const
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:264
CK_TILE_HOST bool AreDescriptorsSmallerThan2GB() const
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:457
CK_TILE_HOST constexpr IndexType GetOriginalN() const
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:265
CK_TILE_HOST TransformConvFwdToGemm(const ConvDimsType &a_g_n_c_wis_lengths, const ConvDimsType &b_g_k_c_xs_lengths, const ConvDimsType &c_g_n_k_wos_lengths, const ConvSpatialDimsType &conv_filter_strides, const ConvSpatialDimsType &conv_filter_dilations, const ConvSpatialDimsType &input_left_pads, const ConvSpatialDimsType &input_right_pads)
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:306
static CK_TILE_HOST SplitImageInfo GetSplitImageInfo(index_t G, index_t N, index_t C, index_t K, index_t D_out, index_t H_out, index_t W_out)
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:143
CK_TILE_HOST auto MakeADescriptor_M_K() const
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:472
CK_TILE_HOST auto MakeBDescriptor_N_K() const
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:1291
CK_TILE_HOST auto MakeCDescriptor_M_N() const
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:1358
CK_TILE_HOST constexpr TransformConvFwdToGemm()
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:267
CK_TILE_HOST TransformConvFwdToGemm(const TransformConvFwdToGemmBase &transform_conv_fwd_to_gemm_base)
Definition tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp:271
Definition tile/core/container/sequence.hpp:49