device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp Source File

device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp Source File

Composable Kernel: device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp Source File
device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <functional>
7#include <iostream>
8#include <iterator>
9#include <numeric>
10#include <queue>
11#include <sstream>
12
27#ifdef CK_EXPERIMENTAL_BUILDER
28#include "ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp"
29#endif
30
31namespace ck {
32namespace tensor_operation {
33namespace device {
34
35namespace {
36
// Kernel for the large-tensor grouped convolution forward op: executes a batch
// of up to MaxGemmsNum GEMMs produced by splitting one convolution on the host
// (see SplitConvProblem). GEMM `i` owns the contiguous blockIdx.x range
// [gemm_desc_kernel_args[i].BlockStart_, gemm_desc_kernel_args[i].BlockEnd_);
// blockIdx.y selects the convolution group and blockIdx.z the N-split slice.
//
// gemm_desc_kernel_args        per-GEMM pointers, descriptors and block ranges
// gemms_count                  number of entries of gemm_desc_kernel_args in use
// a/b/c_element_op             elementwise ops forwarded to GridwiseGemm::Run
// compute_ptr_offset_of_groups per-group pointer offsets for A, B, Ds and E
// compute_ptr_offset_of_n      per-N-slice pointer offsets for A, Ds and E
37template <typename GridwiseGemm,
38 index_t MaxGemmsNum,
39 typename GemmArgs,
40 typename AElementwiseOperation,
41 typename BElementwiseOperation,
42 typename CDEElementwiseOperation,
43 typename ComputePtrOffset,
44 bool HasMainKBlockLoop>
45__global__ void
46#if CK_USE_LAUNCH_BOUNDS
48#endif
49 kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle(
50 Array<GemmArgs, MaxGemmsNum> gemm_desc_kernel_args,
51 const index_t gemms_count,
52 const AElementwiseOperation a_element_op,
53 const BElementwiseOperation b_element_op,
54 const CDEElementwiseOperation c_element_op,
55 const ComputePtrOffset compute_ptr_offset_of_groups,
56 const ComputePtrOffset compute_ptr_offset_of_n)
57{
// Device code is only generated for gfx9 / gfx11 / gfx12 targets; elsewhere the
// body is compiled out and the parameters are only marked as used.
58#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__)
59 if constexpr(GridwiseGemm::template IsValidCompilationParameter<>())
60 {
// LDS scratch buffer for the gridwise GEMM, sized by the GEMM itself.
61 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
62
// x: block id across all split GEMMs, y: convolution group, z: N-split slice.
// Read from the first lane so they are wave-uniform.
63 const index_t block_id_x = __builtin_amdgcn_readfirstlane(blockIdx.x);
64 const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y);
65 const index_t n_idx = __builtin_amdgcn_readfirstlane(blockIdx.z);
66
// Group offsets apply to every tensor (A, B, Ds, E).
67 const long_index_t a_group_offset =
68 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetAPtrOffset(g_idx));
69 const long_index_t b_group_offset =
70 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetBPtrOffset(g_idx));
71 const auto& ds_group_offset = compute_ptr_offset_of_groups.GetDsPtrOffset(g_idx);
72 const long_index_t e_group_offset =
73 amd_wave_read_first_lane(compute_ptr_offset_of_groups.GetEPtrOffset(g_idx));
74
// N-slice offsets apply to A, Ds and E only; the weight tensor (B) is shared
// by all N slices, so no B offset is taken from compute_ptr_offset_of_n.
75 const long_index_t a_n_offset =
76 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetAPtrOffset(n_idx));
77 const auto& ds_n_offset = compute_ptr_offset_of_n.GetDsPtrOffset(n_idx);
78 const long_index_t e_n_offset =
79 amd_wave_read_first_lane(compute_ptr_offset_of_n.GetEPtrOffset(n_idx));
80
// Binary search for the GEMM whose [BlockStart_, BlockEnd_) range contains
// block_id_x. The host (Argument constructor) assigns these ranges back to
// back in ascending order starting at 0.
// NOTE(review): `left` is set to group_id rather than group_id + 1, so the
// loop only terminates when block_id_x lies inside some range, i.e. the launch
// grid must not exceed the last BlockEnd_.
81 index_t left = 0;
82 index_t right = gemms_count;
83 index_t group_id = index_t((left + right) / 2);
84 while((!(block_id_x >= gemm_desc_kernel_args[group_id].BlockStart_ &&
85 block_id_x < gemm_desc_kernel_args[group_id].BlockEnd_)) &&
86 left <= right)
87 {
88 if(block_id_x < gemm_desc_kernel_args[group_id].BlockStart_)
89 {
90 right = group_id;
91 }
92 else
93 {
94 left = group_id;
95 }
96 group_id = index_t((left + right) / 2);
97 }
98
// D-tensor pointers of the selected GEMM, shifted by the group and N offsets.
99 using DsPointer = decltype(gemm_desc_kernel_args[Number<0>{}].ds_ptr_);
100 DsPointer p_ds_grid_grp;
101 static constexpr index_t NumDTensor = DsPointer::Size();
102 static_for<0, NumDTensor, 1>{}([&](auto i) {
103 p_ds_grid_grp(i) =
104 gemm_desc_kernel_args[group_id].ds_ptr_[i] + ds_group_offset[i] + ds_n_offset[i];
105 });
106
// Run the selected GEMM with InMemoryDataOperationEnum::Set for E.
107 GridwiseGemm::template Run<HasMainKBlockLoop, InMemoryDataOperationEnum::Set>(
108 gemm_desc_kernel_args[group_id].a_ptr_ + a_group_offset + a_n_offset,
109 gemm_desc_kernel_args[group_id].b_ptr_ + b_group_offset,
110 p_ds_grid_grp,
111 gemm_desc_kernel_args[group_id].e_ptr_ + e_group_offset + e_n_offset,
112 p_shared,
113 a_element_op,
114 b_element_op,
115 c_element_op,
116 gemm_desc_kernel_args[group_id].a_grid_desc_ak0_m_ak1_,
117 gemm_desc_kernel_args[group_id].b_grid_desc_bk0_n_bk1_,
118 gemm_desc_kernel_args[group_id].ds_grid_desc_mblock_mperblock_nblock_nperblock_,
119 gemm_desc_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_,
120 gemm_desc_kernel_args[group_id].block_2_etile_map_);
121 }
122#else
123 ignore = gemm_desc_kernel_args;
124 ignore = gemms_count;
125 ignore = a_element_op;
126 ignore = b_element_op;
127 ignore = c_element_op;
128 ignore = compute_ptr_offset_of_groups;
129 ignore = compute_ptr_offset_of_n;
130#endif
131}
132
133} // namespace
134
// Detection-idiom probe for `is_detected<is_tuple, T>`: this alias is
// well-formed (and names the return type of IsTuple()) only when an lvalue of
// TupleCandidate has a callable IsTuple() member. It is used to decide whether
// ADataType was passed as a tuple, so the default AComputeDataType can be unpacked.
template <class TupleCandidate>
using is_tuple = decltype(std::declval<TupleCandidate&>().IsTuple());
137
138template <index_t NDimSpatial,
139 typename ALayout,
140 typename BLayout,
141 typename DsLayout,
142 typename ELayout,
143 typename ADataType,
144 typename BDataType,
145 typename AccDataType,
146 typename CShuffleDataType,
147 typename DsDataType,
148 typename EDataType,
149 typename AElementwiseOperation,
150 typename BElementwiseOperation,
151 typename CDEElementwiseOperation,
152 ConvolutionForwardSpecialization ConvForwardSpecialization,
153 GemmSpecialization GemmSpec,
154 index_t NumGemmKPrefetchStage,
155 index_t BlockSize,
156 index_t MPerBlock,
157 index_t NPerBlock,
158 index_t KPerBlock,
159 index_t AK1,
160 index_t BK1,
161 index_t MPerXDL,
162 index_t NPerXDL,
163 index_t MXdlPerWave,
164 index_t NXdlPerWave,
165 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
166 typename ABlockTransferThreadClusterArrangeOrder,
167 typename ABlockTransferSrcAccessOrder,
168 index_t ABlockTransferSrcVectorDim,
169 index_t ABlockTransferSrcScalarPerVector,
170 index_t ABlockTransferDstScalarPerVector_AK1,
171 index_t ABlockLdsExtraM,
172 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
173 typename BBlockTransferThreadClusterArrangeOrder,
174 typename BBlockTransferSrcAccessOrder,
175 index_t BBlockTransferSrcVectorDim,
176 index_t BBlockTransferSrcScalarPerVector,
177 index_t BBlockTransferDstScalarPerVector_BK1,
178 index_t BBlockLdsExtraN,
179 index_t CShuffleMXdlPerWavePerShuffle,
180 index_t CShuffleNXdlPerWavePerShuffle,
181 typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
182 index_t CDEBlockTransferScalarPerVector_NPerBlock,
183 typename AComputeDataType =
184 decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
185 Number<0>,
186 ADataType>()), // ComputeType is InputType by default (first
187 // in tuple for MultiAB), unpack if tuple was
188 // passed
189 typename BComputeDataType = AComputeDataType,
192 : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
193 ALayout,
194 BLayout,
195 DsLayout,
196 ELayout,
197 ADataType,
198 BDataType,
199 DsDataType,
200 EDataType,
201 AElementwiseOperation,
202 BElementwiseOperation,
203 CDEElementwiseOperation,
204 AComputeDataType,
205 BComputeDataType>
206{
// NXdlPerWave variants for 64-lane (GetNXdlPerWave<true>) and 32-lane
// (GetNXdlPerWave<false>) wavefronts. The Argument constructor only builds GEMM
// arguments for the active wave size when the matching value is > 0.
209 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
210 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
211
// Number of auxiliary D tensors.
212 static constexpr index_t NumDTensor = DsDataType::Size();
// Upper bound on the number of GEMMs one convolution may be split into;
// SplitConvProblem additionally caps the number of splits at MaxGemmsNum / 2.
213 static constexpr index_t MaxGemmsNum = 32;
214 static constexpr bool DoElementwiseBeforeCShuffle =
217
// Compile-time index constants used for Number<>-based indexing.
218 static constexpr auto I0 = Number<0>{};
219 static constexpr auto I1 = Number<1>{};
220 static constexpr auto I2 = Number<2>{};
221 static constexpr auto I3 = Number<3>{};
222
224 ConvForwardSpecialization,
225 true /*SplitN*/,
226 ADataType,
227 EDataType,
228 I1,
229 index_t>;
230
232 ConvForwardSpecialization,
233 true /*SplitN*/,
234 ADataType,
235 EDataType,
236 I1,
238
239 static constexpr auto matrix_padder =
240 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
241
242 template <typename ALay>
243 static auto
245 {
246 const auto in_gemmmraw_gemmkraw_desc =
247 conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
248
249 const auto in_gemmm_gemmk_desc =
250 matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
251
252 return in_gemmm_gemmk_desc;
253 }
254
255 template <typename BLay>
256 static auto
258 {
259 const auto wei_gemmnraw_gemmkraw_desc =
260 conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
261
262 const auto wei_gemmn_gemmk_desc =
263 matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
264
265 return wei_gemmn_gemmk_desc;
266 }
267
268 template <typename ELay>
269 static auto
271 {
272 const auto out_gemmmraw_gemmnraw_desc =
273 conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
274
275 const auto out_gemmm_gemmn_desc =
276 matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
277
278 return out_gemmm_gemmn_desc;
279 }
280
281 static auto
283 {
284 return generate_tuple(
285 [&](auto i) {
286 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
287
288 return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
289 },
291 }
292
293 static auto CastDsPointers(const std::array<const void*, NumDTensor>& p_ds)
294 {
295 return generate_tuple(
296 [&](auto i) {
297 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
298 return static_cast<const DDataType*>(p_ds[i]);
299 },
301 }
302
// Tuple of typed `const DDataType*` pointers (one per D tensor): the return
// type of CastDsPointers. Each split GEMM carries one DsPointer.
303 using DsPointer = decltype(CastDsPointers(std::array<const void*, NumDTensor>{}));
304 // desc for problem definition
314
315 static auto
317 const ADataType* a_grid_ptr_base,
318 DsPointer ds_grid_ptr_base,
319 EDataType* c_grid_ptr_base)
320 {
321 // Max number of splits
322 // We need to use it to avoid infinity loop
323 constexpr index_t max_split_numbers = MaxGemmsNum / 2;
324 // Arrays to store transformers with smaller descs than 2GB
325 Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformers_arr;
327 Array<DsPointer, MaxGemmsNum> ds_grid_ptrs_arr;
328 Array<EDataType*, MaxGemmsNum> c_grid_ptrs_arr;
329 // Queue for spliting
330 std::queue<ConvToGemmFwdTransformerLongIndexT> conv_to_gemm_transformers_queue(
331 {conv_to_gemm_transformer_base});
332 std::queue<const ADataType*> a_grid_ptrs_queue({a_grid_ptr_base});
333 std::queue<DsPointer> ds_grid_ptrs_queue({ds_grid_ptr_base});
334 std::queue<EDataType*> c_grid_ptrs_queue({c_grid_ptr_base});
335
336 index_t gemms_number = 0;
337 index_t split_numbers = 0;
338 // Algorithm:
339 // While queue is not empty:
340 // 1. Get transformer from queue.
341 // 2. If descs are smaller than 2GB push to result array.
342 // 3. If descs are bigger than 2GB split into left and right transformer.
343 // and push the both into the queue.
344 while(!conv_to_gemm_transformers_queue.empty() && split_numbers < max_split_numbers &&
345 gemms_number < MaxGemmsNum)
346 {
347 // Get transformer from the queue
348 const auto& conv_to_gemm_transformer = conv_to_gemm_transformers_queue.front();
349 const ADataType* a_grid_ptr = a_grid_ptrs_queue.front();
350 DsPointer ds_grid_ptr = ds_grid_ptrs_queue.front();
351 EDataType* c_grid_ptr = c_grid_ptrs_queue.front();
352
353 // Check if convolution not exceed 2GB
354 if(conv_to_gemm_transformer.AreDescriptorsSmallerThan2GB())
355 {
356 // If yes, push into result array
357 conv_to_gemm_transformers_arr(gemms_number) =
358 ConvToGemmFwdTransformerIndexT{conv_to_gemm_transformer};
359 a_grid_ptrs_arr(gemms_number) = a_grid_ptr;
360 ds_grid_ptrs_arr(gemms_number) = ds_grid_ptr;
361 c_grid_ptrs_arr(gemms_number) = c_grid_ptr;
362 gemms_number++;
363 }
364 else
365 {
366 // If no, split into left and right convolutions
367 ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformers_left_part,
368 conv_to_gemm_transformers_right_part;
369 const ADataType* a_grid_right_ptr;
370 DsPointer ds_grid_right_ptr;
371 EDataType* c_grid_right_ptr;
372
373 ck::tie(conv_to_gemm_transformers_left_part,
374 conv_to_gemm_transformers_right_part,
375 a_grid_right_ptr,
376 ds_grid_right_ptr,
377 c_grid_right_ptr) =
378 conv_to_gemm_transformer.SplitConvProblem(a_grid_ptr, ds_grid_ptr, c_grid_ptr);
379
380 conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_left_part);
381 conv_to_gemm_transformers_queue.push(conv_to_gemm_transformers_right_part);
382 // Left offsets remain the same
383 a_grid_ptrs_queue.push(a_grid_ptr);
384 a_grid_ptrs_queue.push(a_grid_right_ptr);
385 ds_grid_ptrs_queue.push(ds_grid_ptr);
386 ds_grid_ptrs_queue.push(ds_grid_right_ptr);
387 c_grid_ptrs_queue.push(c_grid_ptr);
388 c_grid_ptrs_queue.push(c_grid_right_ptr);
389 split_numbers++;
390 }
391 // Remove from the queue
392 conv_to_gemm_transformers_queue.pop();
393 a_grid_ptrs_queue.pop();
394 ds_grid_ptrs_queue.pop();
395 c_grid_ptrs_queue.pop();
396 }
397
398 const bool is_split_valid = conv_to_gemm_transformers_queue.empty();
399
400 return ck::make_tuple(conv_to_gemm_transformers_arr,
401 a_grid_ptrs_arr,
402 ds_grid_ptrs_arr,
403 c_grid_ptrs_arr,
404 gemms_number,
405 is_split_valid);
406 }
407
// Template-argument list for the gridwise GEMM, spelled once as a macro. It is
// meant to be expanded inside the `template <index_t NXdlPerWave_>` alias that
// follows ("Use appropriate gridwise gemm"). It must therefore forward
// NXdlPerWave_, not the raw NXdlPerWave. Otherwise the alias parameter is
// ignored, and the wave64 / wave32 instantiations (NXdlPerWave64 /
// NXdlPerWave32) would collapse into the same GEMM.
#define GridwiseGemmTemplateParameters                                                            \
    ADataType, BDataType, AComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, \
        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,                    \
        NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL,     \
        NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1,         \
        ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder,                    \
        ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector,                             \
        ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM,                             \
        BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder,    \
        BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim,                                 \
        BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false,            \
        BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,            \
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,                         \
        CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVersion::v1,                \
        AComputeDataType, DoElementwiseBeforeCShuffle
423 // Use appropriate gridwise gemm
424 template <index_t NXdlPerWave_>
428
429 // desc for blockwise copy
432 AGridDesc_M_K{}))>;
435 BGridDesc_N_K{}))>;
438 DsGridDesc_M_N{}))>;
441 EGridDesc_M_N{}))>;
442
443 // block-to-e-tile map
446 // Structure for each gemm(conv)
466
467 // Argument
468 struct Argument : public BaseArgument
469 {
470 template <typename GridwiseGemm, typename DsGridDesc_M_N_, typename EGridDescriptor_M_N_>
471 void init_gemm_args(const ADataType* a_ptr,
472 const BDataType* b_ptr,
473 DsPointer ds_ptr,
474 EDataType* e_ptr,
475 const AGridDesc_M_K& a_grid_desc_m_k,
476 const BGridDesc_N_K& b_grid_desc_n_k,
477 const DsGridDesc_M_N_& ds_grid_desc_m_n,
478 const EGridDescriptor_M_N_& e_grid_desc_m_n,
479 const Block2ETileMap& block_2_etile_map,
480 index_t BlockStart,
481 index_t BlockEnd)
482 {
483 if(GridwiseGemm::CheckValidity(a_grid_desc_m_k,
484 b_grid_desc_n_k,
485 ds_grid_desc_m_n,
486 e_grid_desc_m_n,
487 block_2_etile_map))
488 {
490 GemmArgs{a_ptr,
491 b_ptr,
492 ds_ptr,
493 e_ptr,
494 GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k),
495 GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k),
496 GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
497 ds_grid_desc_m_n),
498 GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
499 e_grid_desc_m_n),
500 block_2_etile_map,
501 BlockStart,
502 BlockEnd};
503
505 }
506 }
507 Argument(const void* p_a,
508 const void* p_b,
509 const std::array<const void*, NumDTensor>& p_ds,
510 void* p_e,
511 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
512 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
513 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
514 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
515 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
516 ds_g_n_k_wos_lengths,
517 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
518 ds_g_n_k_wos_strides,
519 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
520 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
521 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
522 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
523 const std::array<long_index_t, NDimSpatial>& input_left_pads,
524 const std::array<long_index_t, NDimSpatial>& input_right_pads,
525 const AElementwiseOperation& a_element_op,
526 const BElementwiseOperation& b_element_op,
527 const CDEElementwiseOperation& cde_element_op)
528 : num_group_{static_cast<index_t>(a_g_n_c_wis_lengths[0])},
531 a_element_op_{a_element_op},
532 b_element_op_{b_element_op},
533 cde_element_op_{cde_element_op},
534 a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
535 a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
536 b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
537 b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
538 ds_g_n_k_wos_lengths_{ds_g_n_k_wos_lengths},
539 ds_g_n_k_wos_strides_{ds_g_n_k_wos_strides},
540 e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
541 e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
542 conv_filter_strides_{conv_filter_strides},
543 conv_filter_dilations_{conv_filter_dilations},
544 input_left_pads_{input_left_pads},
545 input_right_pads_{input_right_pads}
546 {
547 // Perform grouped gemm, generate array of tranformer for convolution
548 Array<ConvToGemmFwdTransformerIndexT, MaxGemmsNum> conv_to_gemm_transformer_arr;
552
553 DsPointer p_ds_casted = CastDsPointers(p_ds);
554
555 ck::tie(conv_to_gemm_transformer_arr,
556 a_grid_ptrs,
557 ds_grid_ptrs,
558 c_grid_ptrs,
572 static_cast<const ADataType*>(p_a),
573 p_ds_casted,
574 static_cast<EDataType*>(p_e));
575
576 grid_size_ = 0;
578
580 {
581 // Create GemmArg for each gemm(conv)
582 for(index_t i = 0; i < gemms_count_; i++)
583 {
585 conv_to_gemm_transformer_arr[i])};
587 conv_to_gemm_transformer_arr[i])};
588 const auto e_grid_desc_m_n =
589 DeviceOp::MakeEGridDescriptor_M_N<ELayout>(conv_to_gemm_transformer_arr[i]);
590
591 const auto ds_grid_desc_m_n =
592 generate_tuple([&](auto) { return e_grid_desc_m_n; }, Number<NumDTensor>{});
593
594 const auto block_2_etile_map =
596
597 const index_t grid_size_grp =
598 block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
599
600 const index_t BlockStart = grid_size_;
601 const index_t BlockEnd = grid_size_ + grid_size_grp;
602
603 grid_size_ += grid_size_grp;
604
605 if(get_warp_size() == 64)
606 {
607 if constexpr(NXdlPerWave64 > 0)
608 {
609 init_gemm_args<GridwiseGemm64>(a_grid_ptrs[i],
610 static_cast<const BDataType*>(p_b),
611 ds_grid_ptrs[i],
612 c_grid_ptrs[i],
613 a_grid_desc_m_k,
614 b_grid_desc_n_k,
615 ds_grid_desc_m_n,
616 e_grid_desc_m_n,
617 block_2_etile_map,
618 BlockStart,
619 BlockEnd);
620 }
621 }
622 else
623 {
624 if constexpr(NXdlPerWave32 > 0)
625 {
626 init_gemm_args<GridwiseGemm32>(a_grid_ptrs[i],
627 static_cast<const BDataType*>(p_b),
628 ds_grid_ptrs[i],
629 c_grid_ptrs[i],
630 a_grid_desc_m_k,
631 b_grid_desc_n_k,
632 ds_grid_desc_m_n,
633 e_grid_desc_m_n,
634 block_2_etile_map,
635 BlockStart,
636 BlockEnd);
637 }
638 }
639 }
640 // N is the same for all convs
641 conv_N_per_block_ = static_cast<index_t>(conv_to_gemm_transformer_arr[I0].N_);
642 }
643
644 // Strides for G and N remain the same
645 compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides[0];
646 compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides[0];
647 compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides[0];
648
649 compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides[1] * conv_N_per_block_;
650 compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides[1] * conv_N_per_block_;
651
652 static_for<0, NumDTensor, 1>{}([&](auto i) {
654 compute_ptr_offset_of_n_.BatchStrideDs_(i) =
656 });
657 }
658
659 void Print() const
660 {
661 for(index_t i = 0; i < valid_gemms_count_; i++)
662 {
663 std::cout << "A[AK0, M, AK1]: " << gemm_desc_kernel_args_[i].a_grid_desc_ak0_m_ak1_
664 << std::endl;
665 std::cout << "B[BK0, N, BK1]: " << gemm_desc_kernel_args_[i].b_grid_desc_bk0_n_bk1_
666 << std::endl;
667 std::cout
668 << "E[MBlock, MPerBlock, NBlock, NPerBlock]: "
669 << gemm_desc_kernel_args_[i].e_grid_desc_mblock_mperblock_nblock_nperblock_
670 << std::endl;
671 }
672 }
673
676
678
682
684
685 // for computing batch offset
686 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_groups_;
687 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_n_;
688
689 // element-wise op
690 AElementwiseOperation a_element_op_;
691 BElementwiseOperation b_element_op_;
692 CDEElementwiseOperation cde_element_op_;
693
694 // for checking IsSupportedArgument()
695 std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
696 std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
697 std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
698 std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
699 std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
700 std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
701 std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
702 std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
703 std::array<long_index_t, NDimSpatial> conv_filter_strides_;
704 std::array<long_index_t, NDimSpatial> conv_filter_dilations_;
705 std::array<long_index_t, NDimSpatial> input_left_pads_;
706 std::array<long_index_t, NDimSpatial> input_right_pads_;
707 };
708
709 // Invoker
710 struct Invoker : public BaseInvoker
711 {
712
714 template <typename GridwiseGemm>
715 float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
716 {
717 if(stream_config.log_level_ > 0)
718 {
719 arg.Print();
720 }
721
722 const index_t num_workgroups_per_Conv_N =
724
725 const index_t gdx = arg.grid_size_;
726 const index_t gdy = arg.num_group_;
727 const index_t gdz = num_workgroups_per_Conv_N;
728
729 // K is constant for all gemms
730 const auto K = arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I0) *
731 arg.gemm_desc_kernel_args_[I0].a_grid_desc_ak0_m_ak1_.GetLength(I2);
732
733 auto launch_kernel = [&](auto has_main_k_block_loop) {
734 constexpr bool has_main_loop = has_main_k_block_loop.value;
735 const auto kernel = kernel_grouped_conv_fwd_multiple_d_grouped_gemm_xdl_cshuffle<
736 GridwiseGemm,
738 GemmArgs,
739 AElementwiseOperation,
740 BElementwiseOperation,
741 CDEElementwiseOperation,
742 ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
743 has_main_loop>;
744
745 return launch_and_time_kernel(stream_config,
746 kernel,
747 dim3(gdx, gdy, gdz),
748 dim3(BlockSize),
749 0,
751 arg.gemms_count_,
752 arg.a_element_op_,
753 arg.b_element_op_,
754 arg.cde_element_op_,
757 };
758
759 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
760 {
761 return launch_kernel(integral_constant<bool, true>{});
762 }
763 else
764 {
765 return launch_kernel(integral_constant<bool, false>{});
766 }
767 }
768
770
771 float Run(const BaseArgument* p_arg,
772 const StreamConfig& stream_config = StreamConfig{}) override
773 {
774 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
775 }
776 };
777
778 static bool IsSupportedArgument(const Argument& arg)
779 {
780 namespace ctc = tensor_layout::convolution;
781
782 const long_index_t K = arg.b_g_k_c_xs_lengths_[I1];
783 const long_index_t C = arg.b_g_k_c_xs_lengths_[I2];
784
785 bool ds_valid = true;
786 static_for<0, NumDTensor, 1>{}([&](auto i) {
787 for(int d = 0; d < NDimSpatial + I3; d++)
788 {
789 if(arg.ds_g_n_k_wos_strides_[i][d] != arg.e_g_n_k_wos_strides_[d])
790 {
791 ds_valid = false;
792 }
793 if(arg.ds_g_n_k_wos_lengths_[i][d] != arg.e_g_n_k_wos_lengths_[d])
794 {
795 ds_valid = false;
796 }
797 }
798
799 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
800 static_assert(is_same_v<DDataType, EDataType>);
801 });
802
803 if(!ds_valid)
804 {
805 return false;
806 }
807
808 // Check if all descs are valid
809 if(!(arg.is_split_valid_ && arg.gemms_count_ == arg.valid_gemms_count_))
810 {
811 return false;
812 }
813 // check device
814 if(get_device_name() == "gfx908")
815 {
816 // FIXME: re-enable fp64 when SWDEV-335738 is fixed
818 {
819 return false;
820 }
821 }
823 {
824 return false;
825 }
828 {
829 if(!is_tf32_supported())
830 {
831 return false;
832 }
834 {
835 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
836 {
837 std::cout << "ComputeDataType for A and B should be same while using TF32"
838 << std::endl;
839 }
840 return false;
841 }
842 }
843 // check ConvolutionForwardSpecialization
844 if constexpr(ConvForwardSpecialization ==
846 {
847 // check if it's 1x1, stride=1 conv
848 for(index_t i = 0; i < NDimSpatial; ++i)
849 {
850 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
851 const index_t ConvStride = arg.conv_filter_strides_[i];
852 const index_t LeftPad = arg.input_left_pads_[i];
853 const index_t RightPad = arg.input_right_pads_[i];
854
855 if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
856 {
857 return false;
858 }
859 }
860 }
861 else if constexpr(ConvForwardSpecialization ==
863 {
864 // check if it's 1x1 conv
865 for(index_t i = 0; i < NDimSpatial; ++i)
866 {
867 const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
868 const index_t LeftPad = arg.input_left_pads_[i];
869 const index_t RightPad = arg.input_right_pads_[i];
870
871 if(!(X == 1 && LeftPad == 0 && RightPad == 0))
872 {
873 return false;
874 }
875 }
876 }
877 else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
878 {
879 if(C != 1)
880 {
881 return false;
882 }
883 for(index_t i = 0; i < NDimSpatial; ++i)
884 {
885 const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];
886
887 if(filter_spatial_dim != I3)
888 {
889 return false;
890 }
891 }
893 {
894 return false;
895 }
896 }
897
898 // check vector access of A
899 // FIXME: layout
905 {
906 // Check access per C
907 if(!(ABlockTransferSrcVectorDim == 2 && C % ABlockTransferSrcScalarPerVector == 0))
908 {
909 return false;
910 }
911 }
912 else
913 {
914 return false;
915 }
916
917 // check vector access of B
918 // FIXME: layout
924
925 {
926 if(!(BBlockTransferSrcVectorDim == 2 && C % BBlockTransferSrcScalarPerVector == 0))
927 {
928 return false;
929 }
930 }
931 else
932 {
933 return false;
934 }
935
936 // check vector access of E
942 {
943 if(!(K % CDEBlockTransferScalarPerVector_NPerBlock == 0))
944 {
945 return false;
946 }
947 }
948 else
949 {
950 return false;
951 }
952
953 return true;
954 }
955
956 bool IsSupportedArgument(const BaseArgument* p_arg) override
957 {
958 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
959 }
960
961 static auto MakeArgument(
962 const void* p_a,
963 const void* p_b,
964 const std::array<const void*, NumDTensor>& p_ds,
965 void* p_e,
966 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
967 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
968 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
969 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
970 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
971 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
972 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
973 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
974 const std::array<index_t, NDimSpatial>& conv_filter_strides,
975 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
976 const std::array<index_t, NDimSpatial>& input_left_pads,
977 const std::array<index_t, NDimSpatial>& input_right_pads,
978 const AElementwiseOperation& a_element_op,
979 const BElementwiseOperation& b_element_op,
980 const CDEElementwiseOperation& cde_element_op)
981 {
982 std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
983 std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
984 std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
985 std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
986 std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
987 std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
988 std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
989 std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
990 std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
991 std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
992 std::array<long_index_t, NDimSpatial> input_left_pads_i64;
993 std::array<long_index_t, NDimSpatial> input_right_pads_i64;
994
995 array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
996 array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
997 array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
998 array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
999 for(index_t d = 0; d < NumDTensor; d++)
1000 {
1001 array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
1002 array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
1003 }
1004 array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
1005 array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
1006 array_convert(conv_filter_strides_i64, conv_filter_strides);
1007 array_convert(conv_filter_dilations_i64, conv_filter_dilations);
1008 array_convert(input_left_pads_i64, input_left_pads);
1009 array_convert(input_right_pads_i64, input_right_pads);
1010
1011 return Argument{p_a,
1012 p_b,
1013 p_ds,
1014 p_e,
1015 a_g_n_c_wis_lengths_i64,
1016 a_g_n_c_wis_strides_i64,
1017 b_g_k_c_xs_lengths_i64,
1018 b_g_k_c_xs_strides_i64,
1019 ds_g_n_k_wos_lengths_i64,
1020 ds_g_n_k_wos_strides_i64,
1021 e_g_n_k_wos_lengths_i64,
1022 e_g_n_k_wos_strides_i64,
1023 conv_filter_strides_i64,
1024 conv_filter_dilations_i64,
1025 input_left_pads_i64,
1026 input_right_pads_i64,
1027 a_element_op,
1028 b_element_op,
1029 cde_element_op};
1030 }
1031
1032 static auto
1033 MakeArgument(const void* p_a,
1034 const void* p_b,
1035 const std::array<const void*, NumDTensor>& p_ds,
1036 void* p_e,
1037 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1038 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1039 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1040 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1041 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1042 ds_g_n_k_wos_lengths,
1043 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1044 ds_g_n_k_wos_strides,
1045 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1046 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1047 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1048 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1049 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1050 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1051 const AElementwiseOperation& a_element_op,
1052 const BElementwiseOperation& b_element_op,
1053 const CDEElementwiseOperation& cde_element_op)
1054 {
1055 return Argument{p_a,
1056 p_b,
1057 p_ds,
1058 p_e,
1059 a_g_n_c_wis_lengths,
1060 a_g_n_c_wis_strides,
1061 b_g_k_c_xs_lengths,
1062 b_g_k_c_xs_strides,
1063 ds_g_n_k_wos_lengths,
1064 ds_g_n_k_wos_strides,
1065 e_g_n_k_wos_lengths,
1066 e_g_n_k_wos_strides,
1067 conv_filter_strides,
1068 conv_filter_dilations,
1069 input_left_pads,
1070 input_right_pads,
1071 a_element_op,
1072 b_element_op,
1073 cde_element_op};
1074 }
1075
1076 static auto MakeInvoker() { return Invoker{}; }
1077
1078 std::unique_ptr<BaseArgument> MakeArgumentPointer(
1079 const void* p_a,
1080 const void* p_b,
1081 const std::array<const void*, NumDTensor>& p_ds,
1082 void* p_e,
1083 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1084 const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1085 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1086 const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1087 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
1088 const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
1089 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1090 const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1091 const std::array<index_t, NDimSpatial>& conv_filter_strides,
1092 const std::array<index_t, NDimSpatial>& conv_filter_dilations,
1093 const std::array<index_t, NDimSpatial>& input_left_pads,
1094 const std::array<index_t, NDimSpatial>& input_right_pads,
1095 const AElementwiseOperation& a_element_op,
1096 const BElementwiseOperation& b_element_op,
1097 const CDEElementwiseOperation& cde_element_op) override
1098 {
1099
1100 std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i64;
1101 std::array<long_index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i64;
1102 std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i64;
1103 std::array<long_index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i64;
1104 std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i64;
1105 std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i64;
1106 std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i64;
1107 std::array<long_index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i64;
1108 std::array<long_index_t, NDimSpatial> conv_filter_strides_i64;
1109 std::array<long_index_t, NDimSpatial> conv_filter_dilations_i64;
1110 std::array<long_index_t, NDimSpatial> input_left_pads_i64;
1111 std::array<long_index_t, NDimSpatial> input_right_pads_i64;
1112
1113 array_convert(a_g_n_c_wis_lengths_i64, a_g_n_c_wis_lengths);
1114 array_convert(a_g_n_c_wis_strides_i64, a_g_n_c_wis_strides);
1115 array_convert(b_g_k_c_xs_lengths_i64, b_g_k_c_xs_lengths);
1116 array_convert(b_g_k_c_xs_strides_i64, b_g_k_c_xs_strides);
1117 for(index_t d = 0; d < NumDTensor; d++)
1118 {
1119 array_convert(ds_g_n_k_wos_lengths_i64[d], ds_g_n_k_wos_lengths[d]);
1120 array_convert(ds_g_n_k_wos_strides_i64[d], ds_g_n_k_wos_strides[d]);
1121 }
1122 array_convert(e_g_n_k_wos_lengths_i64, e_g_n_k_wos_lengths);
1123 array_convert(e_g_n_k_wos_strides_i64, e_g_n_k_wos_strides);
1124 array_convert(conv_filter_strides_i64, conv_filter_strides);
1125 array_convert(conv_filter_dilations_i64, conv_filter_dilations);
1126 array_convert(input_left_pads_i64, input_left_pads);
1127 array_convert(input_right_pads_i64, input_right_pads);
1128
1129 return std::make_unique<Argument>(p_a,
1130 p_b,
1131 p_ds,
1132 p_e,
1133 a_g_n_c_wis_lengths_i64,
1134 a_g_n_c_wis_strides_i64,
1135 b_g_k_c_xs_lengths_i64,
1136 b_g_k_c_xs_strides_i64,
1137 ds_g_n_k_wos_lengths_i64,
1138 ds_g_n_k_wos_strides_i64,
1139 e_g_n_k_wos_lengths_i64,
1140 e_g_n_k_wos_strides_i64,
1141 conv_filter_strides_i64,
1142 conv_filter_dilations_i64,
1143 input_left_pads_i64,
1144 input_right_pads_i64,
1145 a_element_op,
1146 b_element_op,
1147 cde_element_op);
1148 }
1149
1150 std::unique_ptr<BaseArgument>
1151 MakeArgumentPointer(const void* p_a,
1152 const void* p_b,
1153 const std::array<const void*, NumDTensor>& p_ds,
1154 void* p_e,
1155 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
1156 const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
1157 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
1158 const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
1159 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1160 ds_g_n_k_wos_lengths,
1161 const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
1162 ds_g_n_k_wos_strides,
1163 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
1164 const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
1165 const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
1166 const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
1167 const std::array<long_index_t, NDimSpatial>& input_left_pads,
1168 const std::array<long_index_t, NDimSpatial>& input_right_pads,
1169 const AElementwiseOperation& a_element_op,
1170 const BElementwiseOperation& b_element_op,
1171 const CDEElementwiseOperation& cde_element_op) override
1172 {
1173
1174 return std::make_unique<Argument>(p_a,
1175 p_b,
1176 p_ds,
1177 p_e,
1178 a_g_n_c_wis_lengths,
1179 a_g_n_c_wis_strides,
1180 b_g_k_c_xs_lengths,
1181 b_g_k_c_xs_strides,
1182 ds_g_n_k_wos_lengths,
1183 ds_g_n_k_wos_strides,
1184 e_g_n_k_wos_lengths,
1185 e_g_n_k_wos_strides,
1186 conv_filter_strides,
1187 conv_filter_dilations,
1188 input_left_pads,
1189 input_right_pads,
1190 a_element_op,
1191 b_element_op,
1192 cde_element_op);
1193 }
1194
1195 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1196 {
1197 return std::make_unique<Invoker>(Invoker{});
1198 }
1199
1200 std::string GetTypeString() const override
1201 {
1202 auto str = std::stringstream();
1203
1204 // clang-format off
1205 str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor"
1206 << "<"
1207 << BlockSize << ", "
1208 << MPerBlock << ", "
1209 << NPerBlock << ", "
1210 << KPerBlock << ", "
1211 << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
1212 << MPerXDL << ", "
1213 << NPerXDL << ", "
1214 << MXdlPerWave << ", "
1215 << NXdlPerWave << ", "
1216 << ABlockTransferSrcScalarPerVector << ", "
1217 << BBlockTransferSrcScalarPerVector << ", "
1218 << CDEBlockTransferScalarPerVector_NPerBlock << ", "
1219 << CShuffleMXdlPerWavePerShuffle << ", "
1220 << CShuffleNXdlPerWavePerShuffle
1221 << ">";
1222 // clang-format on
1223
1224 return str.str();
1225 }
1226
#ifdef CK_EXPERIMENTAL_BUILDER
// Returns the reflection-generated instance string for this DeviceOp.
// The static_assert turns a missing instance_traits specialization into a
// compile-time error that points at the header where it should live.
std::string GetInstanceString() const override
{
    namespace reflect = ck_tile::reflect;

    static_assert(
        reflect::HasInstanceTraits<DeviceOp>,
        "Specialization of instance_traits not found. Please check that a "
        "specialization exists in file "
        "ck_tile/builder/reflect/"
        "instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp "
        "for the given template parameters.");

    return reflect::instance_string<DeviceOp>();
}
#endif
1240};
1241
1242} // namespace device
1243} // namespace tensor_operation
1244} // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition ck.hpp:31
#define CK_MAX_THREAD_PER_BLOCK
Definition ck.hpp:30
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition tensor_operation/gpu/device/tensor_layout.hpp:42
Definition convolution_backward_data_specialization.hpp:8
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK()
Definition device_grouped_conv_utils.hpp:119
GemmSpecialization
Definition gemm_specialization.hpp:11
decltype(std::declval< T & >().IsTuple()) is_tuple
Definition device_grouped_conv_fwd_multiple_abd.hpp:23
ConvolutionForwardSpecialization
Definition convolution_forward_specialization.hpp:15
@ Filter1x1Stride1Pad0
Definition convolution_forward_specialization.hpp:18
@ Filter3x3
Definition convolution_forward_specialization.hpp:20
@ Filter1x1Pad0
Definition convolution_forward_specialization.hpp:17
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition convolution_forward_specialization.hpp:24
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition amd_wave_read_first_lane.hpp:100
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
bool is_tf32_supported()
Definition host_utility/device_prop.hpp:132
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition utility/type_convert.hpp:2466
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
int64_t long_index_t
Definition ck.hpp:300
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static __device__ constexpr auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:207
__host__ static __device__ constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:190
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:224
__host__ static __device__ constexpr auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:257
__host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:245
Definition multi_index_transform.hpp:196
Definition multi_index_transform.hpp:284
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:25
__host__ auto SplitConvProblem(const ADataType *a_grid_ptr_base, DsPointer &ds_grid_ptr_base, CDataType *c_grid_ptr_base) const
Definition tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp:393
Definition device_base.hpp:197
virtual std::string GetInstanceString() const
Definition device_base.hpp:230
Grouped Convolution Forward.
Definition device_grouped_conv_fwd_multiple_abd.hpp:73
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:469
bool is_split_valid_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:683
index_t gemms_count_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:680
index_t valid_gemms_count_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:681
std::array< long_index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:697
BElementwiseOperation b_element_op_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:691
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_groups_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:686
std::array< long_index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:701
std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:699
void init_gemm_args(const ADataType *a_ptr, const BDataType *b_ptr, DsPointer ds_ptr, EDataType *e_ptr, const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N_ &ds_grid_desc_m_n, const EGridDescriptor_M_N_ &e_grid_desc_m_n, const Block2ETileMap &block_2_etile_map, index_t BlockStart, index_t BlockEnd)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:471
std::array< long_index_t, NDimSpatial > conv_filter_dilations_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:704
index_t num_group_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:674
std::array< long_index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:702
Array< GemmArgs, MaxGemmsNum > gemm_desc_kernel_args_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:677
CDEElementwiseOperation cde_element_op_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:692
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:507
AElementwiseOperation a_element_op_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:690
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_n_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:687
std::array< long_index_t, NDimSpatial > input_right_pads_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:706
std::array< long_index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:696
std::array< long_index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:695
void Print() const
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:659
std::array< long_index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:698
index_t conv_N_per_block_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:675
std::array< long_index_t, NDimSpatial > input_left_pads_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:705
std::array< long_index_t, NDimSpatial > conv_filter_strides_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:703
index_t grid_size_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:679
std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:700
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:448
DsPointer ds_ptr_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:452
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:457
ck::index_t BlockEnd_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:464
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:456
const ADataType * a_ptr_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:450
const BDataType * b_ptr_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:451
Block2ETileMap block_2_etile_map_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:463
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:459
EDataType * e_ptr_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:453
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:460
ck::index_t BlockStart_
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:464
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:711
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:715
DeviceOp::Argument Argument
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:713
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:771
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:206
static auto CastDsPointers(const std::array< const void *, NumDTensor > &p_ds)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:293
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:1033
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:1195
std::string GetTypeString() const override
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:1200
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:439
remove_cvref_t< decltype(MakeBGridDescriptor_N_K< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_N_K
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:308
GridwiseGemmMultipleD_xdl_cshuffle< GridwiseGemmTemplateParameters > GridwiseGemmBase
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:425
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:426
static constexpr auto I2
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:220
static auto MakeBGridDescriptor_N_K(const ConvToGemmFwdTransformerIndexT &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:257
static constexpr index_t MaxGemmsNum
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:213
static bool IsSupportedArgument(const Argument &arg)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:778
static constexpr auto I0
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:218
decltype(CastDsPointers(std::array< const void *, NumDTensor >{})) DsPointer
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:303
static constexpr bool DoElementwiseBeforeCShuffle
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:214
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:444
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:956
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:310
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:282
static constexpr ConvToGemmFwdTransformerIndexT dummy_conv_to_gemm_transformer
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:305
static constexpr index_t NumDTensor
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:212
static auto MakeAGridDescriptor_M_K(const ConvToGemmFwdTransformerIndexT &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:244
remove_cvref_t< decltype(MakeAGridDescriptor_M_K< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_M_K
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:306
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:961
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:312
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor DeviceOp
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:207
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:1078
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformerIndexT &conv_to_gemm_transformer)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:270
static constexpr auto NXdlPerWave32
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:210
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:430
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization, true, ADataType, EDataType, I1, index_t > ConvToGemmFwdTransformerIndexT
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:223
static auto GenerateConvToGemmTransforms(ConvToGemmFwdTransformerLongIndexT conv_to_gemm_transformer_base, const ADataType *a_grid_ptr_base, DsPointer ds_grid_ptr_base, EDataType *c_grid_ptr_base)
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:316
static constexpr auto matrix_padder
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:239
TransformConvFwdToGemm< NDimSpatial, ConvForwardSpecialization, true, ADataType, EDataType, I1, long_index_t > ConvToGemmFwdTransformerLongIndexT
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:231
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:427
static constexpr auto I3
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:221
static auto MakeInvoker()
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:1076
static constexpr auto I1
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:219
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:209
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:436
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:433
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp:1151
Definition matrix_padder.hpp:180
#define CK_ENV(name)
Definition utility/env.hpp:129