30 typename InElementwiseOperation,
31 typename WeiElementwiseOperation,
32 typename OutElementwiseOperation,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
56 bool BBlockLdsAddExtraN,
62 ck::tuple_element_t<NDimSpatial - 1,
63 ck::Tuple<ck::tensor_layout::convolution::NWC,
64 ck::tensor_layout::convolution::NHWC,
65 ck::tensor_layout::convolution::NDHWC>>,
66 ck::tuple_element_t<NDimSpatial - 1,
67 ck::Tuple<ck::tensor_layout::convolution::KXC,
68 ck::tensor_layout::convolution::KYXC,
69 ck::tensor_layout::convolution::KZYXC>>,
70 ck::tuple_element_t<NDimSpatial - 1,
71 ck::Tuple<ck::tensor_layout::convolution::NWK,
72 ck::tensor_layout::convolution::NHWK,
73 ck::tensor_layout::convolution::NDHWK>>,
77 InElementwiseOperation,
78 WeiElementwiseOperation,
79 OutElementwiseOperation>
103 static_assert((K1 % ABlockTransferThreadClusterLengths_K0_M_K1{}[
I2]) %
104 ABlockTransferSrcScalarPerVector ==
106 static_assert((NPerBlock / BBlockTransferThreadClusterLengths_K0_N_K1{}[
I1]) %
107 BBlockTransferSrcScalarPerVector ==
113 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
118 std::vector<ck::index_t> input_spatial_lengths,
119 std::vector<ck::index_t> filter_spatial_lengths,
120 std::vector<ck::index_t> output_spatial_lengths,
121 std::vector<ck::index_t> conv_filter_strides,
122 std::vector<ck::index_t> conv_filter_dilations,
123 std::vector<ck::index_t> input_left_pads,
124 std::vector<ck::index_t> input_right_pads,
125 std::vector<ck::index_t> tildes)
131 const index_t Wi = input_spatial_lengths[0];
132 const index_t Wo = output_spatial_lengths[0];
133 const index_t X = filter_spatial_lengths[0];
134 const index_t InLeftPadW = input_left_pads[0];
135 const index_t InRightPadW = input_right_pads[0];
136 const index_t ConvStrideW = conv_filter_strides[0];
137 const index_t ConvDilationW = conv_filter_dilations[0];
139 const auto K0 = K / K1;
143 if constexpr(ConvBackwardDataSpecialization ==
155 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
172 in_n_x_wo_c_grid_desc,
179 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
180 wei_gemmk0_gemmn_gemmk1_grid_desc,
181 in_gemmm_gemmn_grid_desc);
185 const auto out_n_wo_k_grid_desc =
187 const auto wei_k_x_c_grid_desc =
190 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
192 const auto XTilde = ConvStrideW / GcdStrideDilationW;
201 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
206 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
213 out_n_wo_k_grid_desc,
221 out_n_wop_k_grid_desc,
231 out_n_xdot_wtilde_k_grid_desc,
240 out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
258 wei_k_xdot_xtilde_c_grid_desc,
267 wei_k0_k1_xdotslice_c_grid_desc,
284 in_n_wip_c_grid_desc,
293 in_n_xtilde_wtilde_c_grid_desc,
302 in_n_wtildeslice_c_grid_desc,
308 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
309 wei_gemmk0_gemmn_gemmk1_grid_desc,
310 in_gemmm_gemmn_grid_desc);
314 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
319 std::vector<ck::index_t> input_spatial_lengths,
320 std::vector<ck::index_t> filter_spatial_lengths,
321 std::vector<ck::index_t> output_spatial_lengths,
322 std::vector<ck::index_t> conv_filter_strides,
323 std::vector<ck::index_t> conv_filter_dilations,
324 std::vector<ck::index_t> input_left_pads,
325 std::vector<ck::index_t> input_right_pads,
326 std::vector<ck::index_t> tildes)
333 const index_t Hi = input_spatial_lengths[0];
334 const index_t Wi = input_spatial_lengths[1];
336 const index_t Ho = output_spatial_lengths[0];
337 const index_t Wo = output_spatial_lengths[1];
339 const index_t Y = filter_spatial_lengths[0];
340 const index_t X = filter_spatial_lengths[1];
342 const index_t InLeftPadH = input_left_pads[0];
343 const index_t InLeftPadW = input_left_pads[1];
345 const index_t InRightPadH = input_right_pads[0];
346 const index_t InRightPadW = input_right_pads[1];
348 const index_t ConvStrideH = conv_filter_strides[0];
349 const index_t ConvStrideW = conv_filter_strides[1];
351 const index_t ConvDilationH = conv_filter_dilations[0];
352 const index_t ConvDilationW = conv_filter_dilations[1];
354 const auto K0 = K / K1;
356 const auto out_n_ho_wo_k_grid_desc =
358 const auto wei_k_y_x_c_grid_desc =
360 const auto in_n_hi_wi_c_grid_desc =
363 if constexpr(ConvBackwardDataSpecialization ==
375 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
384 in_n_hi_wi_c_grid_desc,
393 in_n_y_ho_x_wo_c_grid_desc,
401 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
402 wei_gemmk0_gemmn_gemmk1_grid_desc,
403 in_gemmm_gemmn_grid_desc);
407 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
408 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
410 const auto YTilde = ConvStrideH / GcdStrideDilationH;
411 const auto XTilde = ConvStrideW / GcdStrideDilationW;
423 math::max(
I0, InLeftPadH - ConvDilationH * (YTilde -
I1)), ConvStrideH);
425 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
432 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
433 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
441 out_n_ho_wo_k_grid_desc,
450 out_n_hop_wop_k_grid_desc,
461 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
463 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
484 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
493 wei_k_y_x_c_grid_desc,
503 const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
525 wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
534 in_n_hi_wi_c_grid_desc,
543 in_n_hip_wip_c_grid_desc,
554 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
575 in_n_htildeslice_wtildeslice_c_grid_desc,
581 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
582 wei_gemmk0_gemmn_gemmk1_grid_desc,
583 in_gemmm_gemmn_grid_desc);
588 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
593 std::vector<ck::index_t> input_spatial_lengths,
594 std::vector<ck::index_t> filter_spatial_lengths,
595 std::vector<ck::index_t> output_spatial_lengths,
596 std::vector<ck::index_t> conv_filter_strides,
597 std::vector<ck::index_t> conv_filter_dilations,
598 std::vector<ck::index_t> input_left_pads,
599 std::vector<ck::index_t> input_right_pads,
600 std::vector<ck::index_t> tildes)
604 const index_t i_ztilde = tildes[0];
605 const index_t i_ytilde = tildes[1];
606 const index_t i_xtilde = tildes[2];
608 const index_t Di = input_spatial_lengths[0];
609 const index_t Hi = input_spatial_lengths[1];
610 const index_t Wi = input_spatial_lengths[2];
612 const index_t Do = output_spatial_lengths[0];
613 const index_t Ho = output_spatial_lengths[1];
614 const index_t Wo = output_spatial_lengths[2];
616 const index_t Z = filter_spatial_lengths[0];
617 const index_t Y = filter_spatial_lengths[1];
618 const index_t X = filter_spatial_lengths[2];
620 const index_t InLeftPadD = input_left_pads[0];
621 const index_t InLeftPadH = input_left_pads[1];
622 const index_t InLeftPadW = input_left_pads[2];
624 const index_t InRightPadD = input_right_pads[0];
625 const index_t InRightPadH = input_right_pads[1];
626 const index_t InRightPadW = input_right_pads[2];
628 const index_t ConvStrideD = conv_filter_strides[0];
629 const index_t ConvStrideH = conv_filter_strides[1];
630 const index_t ConvStrideW = conv_filter_strides[2];
632 const index_t ConvDilationD = conv_filter_dilations[0];
633 const index_t ConvDilationH = conv_filter_dilations[1];
634 const index_t ConvDilationW = conv_filter_dilations[2];
636 const auto K0 = K / K1;
638 const auto out_n_do_ho_wo_k_grid_desc =
640 const auto wei_k_z_y_x_c_grid_desc =
642 const auto in_n_di_hi_wi_c_grid_desc =
645 if constexpr(ConvBackwardDataSpecialization ==
657 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
666 in_n_di_hi_wi_c_grid_desc,
681 in_n_z_do_y_ho_x_wo_c_grid_desc,
694 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
695 wei_gemmk0_gemmn_gemmk1_grid_desc,
696 in_gemmm_gemmn_grid_desc);
700 const auto GcdStrideDilationD =
math::gcd(ConvStrideD, ConvDilationD);
701 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
702 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
704 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
705 const auto YTilde = ConvStrideH / GcdStrideDilationH;
706 const auto XTilde = ConvStrideW / GcdStrideDilationW;
721 math::max(
I0, InLeftPadD - ConvDilationD * (ZTilde -
I1)), ConvStrideD);
723 math::max(
I0, InLeftPadH - ConvDilationH * (YTilde -
I1)), ConvStrideH);
725 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
734 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
735 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
736 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
745 out_n_do_ho_wo_k_grid_desc,
756 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
758 out_n_dop_hop_wop_k_grid_desc,
777 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
779 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
806 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
815 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
817 wei_k_z_y_x_c_grid_desc,
835 const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
863 wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
872 in_n_di_hi_wi_c_grid_desc,
883 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
885 in_n_dip_hip_wip_c_grid_desc,
902 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
904 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
931 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
938 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
939 wei_gemmk0_gemmn_gemmk1_grid_desc,
940 in_gemmm_gemmn_grid_desc);
945 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
949 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
952 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
956 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
959 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
982 template <index_t NXdlPerWave_>
989 InElementwiseOperation,
990 WeiElementwiseOperation,
991 OutElementwiseOperation,
1000 ABlockTransferThreadClusterLengths_K0_M_K1,
1001 ABlockTransferThreadClusterArrangeOrder,
1002 ABlockTransferSrcAccessOrder,
1003 ABlockTransferSrcVectorDim,
1004 ABlockTransferSrcScalarPerVector,
1005 ABlockTransferDstScalarPerVector_K1,
1008 BBlockTransferThreadClusterLengths_K0_N_K1,
1009 BBlockTransferThreadClusterArrangeOrder,
1010 BBlockTransferSrcAccessOrder,
1011 BBlockTransferSrcVectorDim,
1012 BBlockTransferSrcScalarPerVector,
1013 BBlockTransferDstScalarPerVector_K1,
1018 CThreadTransferDstScalarPerVector>;
1026 const WeiDataType* p_wei_grid,
1027 const OutDataType* p_out_grid,
1031 std::vector<ck::index_t> input_spatial_lengths,
1032 std::vector<ck::index_t> filter_spatial_lengths,
1033 std::vector<ck::index_t> output_spatial_lengths,
1034 std::vector<ck::index_t> conv_filter_strides,
1035 std::vector<ck::index_t> conv_filter_dilations,
1036 std::vector<ck::index_t> input_left_pads,
1037 std::vector<ck::index_t> input_right_pads)
1055 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
1060 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1061 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1065 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1092 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
1101 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
1102 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1104 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1105 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1109 for(
index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1111 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1116 if(YDotSlice * XDotSlice <= 0)
1133 {i_ytilde, i_xtilde});
1140 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
1151 const auto GcdStrideDilationD =
math::gcd(ConvStrideD, ConvDilationD);
1152 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
1153 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1155 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
1156 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1157 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1162 for(
index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
1164 for(
index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1166 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1172 if(ZDotSlice * YDotSlice * XDotSlice <= 0)
1189 {i_ztilde, i_ytilde, i_xtilde});
1226 template <
typename Gr
idwiseGemm>
1234 std::cout <<
"arg.a_grid_desc_k0_m_k1{"
1240 std::cout <<
"arg.b_grid_desc_k0_n_k1{"
1246 std::cout <<
"arg.c_grid_desc_m_n{"
1256 throw std::runtime_error(
1257 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
1260 const auto [gdx, gdy, gdz] =
1266 if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
1279 dim3(gdx, gdy, gdz),
1302 dim3(gdx, gdy, gdz),
1321 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1337 if constexpr(ConvBackwardDataSpecialization ==
1341 for(
int i = 0; i < NDimSpatial; i++)
1352 if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 1 &&
1353 arg.
Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
1354 arg.
Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
1360 if(!(arg.
Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1400 const WeiDataType* p_wei_grid,
1401 const OutDataType* p_out_grid,
1405 std::vector<ck::index_t> input_spatial_lengths,
1406 std::vector<ck::index_t> filter_spatial_lengths,
1407 std::vector<ck::index_t> output_spatial_lengths,
1408 std::vector<ck::index_t> conv_filter_strides,
1409 std::vector<ck::index_t> conv_filter_dilations,
1410 std::vector<ck::index_t> input_left_pads,
1411 std::vector<ck::index_t> input_right_pads)
1419 input_spatial_lengths,
1420 filter_spatial_lengths,
1421 output_spatial_lengths,
1422 conv_filter_strides,
1423 conv_filter_dilations,
1430 std::unique_ptr<BaseArgument>
1432 const void* p_wei_grid,
1433 const void* p_out_grid,
1437 std::vector<ck::index_t> input_spatial_lengths,
1438 std::vector<ck::index_t> filter_spatial_lengths,
1439 std::vector<ck::index_t> output_spatial_lengths,
1440 std::vector<ck::index_t> conv_filter_strides,
1441 std::vector<ck::index_t> conv_filter_dilations,
1442 std::vector<ck::index_t> input_left_pads,
1443 std::vector<ck::index_t> input_right_pads,
1444 InElementwiseOperation,
1445 WeiElementwiseOperation,
1446 OutElementwiseOperation)
override
1448 return std::make_unique<Argument>(
static_cast<InDataType*
>(p_in_grid),
1449 static_cast<const WeiDataType*
>(p_wei_grid),
1450 static_cast<const OutDataType*
>(p_out_grid),
1454 input_spatial_lengths,
1455 filter_spatial_lengths,
1456 output_spatial_lengths,
1457 conv_filter_strides,
1458 conv_filter_dilations,
1465 return std::make_unique<Invoker>(
Invoker{});
1470 auto str = std::stringstream();
1473 str <<
"DeviceConvNdBwdDataNwcKxcNwk_Xdl"
1475 << BlockSize <<
", "
1476 << MPerBlock <<
", "
1477 << NPerBlock <<
", "
1478 << K0PerBlock <<
", "
1480 << MXdlPerWave <<
", "
1481 << NXdlPerWave <<
", "
1482 << ABlockTransferSrcScalarPerVector <<
", "
1483 << ABlockTransferDstScalarPerVector_K1 <<
", "
1484 << BBlockTransferSrcScalarPerVector <<
", "
1485 << BBlockTransferDstScalarPerVector_K1
1487 if constexpr(ConvBackwardDataSpecialization ==
1490 str<<
" Filter1x1Stride1Pad0";
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
#define INVOKER_RUN_IMPL
Definition device_base.hpp:94
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
Definition utility/math.hpp:154
__host__ __device__ constexpr auto integer_divide_floor(X x, Y y)
Definition utility/math.hpp:66
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
__host__ __device__ constexpr auto integer_divide_ceil(X x, Y y)
Definition utility/math.hpp:72
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
Definition convolution_backward_data_specialization.hpp:8
ConvolutionBackwardDataSpecialization
Definition convolution_backward_data_specialization.hpp:11
@ Filter1x1Stride1Pad0
Definition convolution_backward_data_specialization.hpp:13
Definition convolution_backward_data_specialization.hpp:7
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition multi_index_transform_helper.hpp:163
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_freeze_transform(const LowerIndex &low_idx)
Definition multi_index_transform_helper.hpp:151
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:19
__host__ __device__ constexpr auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition multi_index_transform_helper.hpp:48
__host__ __device__ constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition multi_index_transform_helper.hpp:55
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:34
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdlops_v2r3.hpp:142
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector >::CheckValidity __host__ static __device__ constexpr bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition gridwise_gemm_xdlops_v2r3.hpp:356
Definition utility/sequence.hpp:43
Definition device_base.hpp:197
Definition device_conv_bwd_data.hpp:25
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1024
WeiElementwiseOperation b_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1205
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1025
InElementwiseOperation c_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1206
std::vector< ck::index_t > conv_filter_dilations_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1216
std::vector< ck::index_t > input_left_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1217
std::vector< ck::index_t > input_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1212
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1202
index_t Conv_N_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1208
CDataType * p_c_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1200
index_t Conv_C_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1210
std::vector< ck::index_t > input_right_pads_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1218
const BDataType * p_b_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1199
OutElementwiseOperation a_element_op_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1204
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1203
index_t Conv_K_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1209
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1201
void CreateABCDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1056
std::vector< ck::index_t > filter_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1213
std::vector< ck::index_t > output_spatial_lengths_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1214
const ADataType * p_a_grid_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1198
std::vector< ck::index_t > conv_filter_strides_
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1215
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1223
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1318
DeviceOp::Argument Argument
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1224
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1227
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:80
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:978
static constexpr auto I6
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:100
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:975
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > tildes)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:115
static auto GetABCGridDesc()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:946
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1463
static constexpr auto I5
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:99
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:84
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:979
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1019
static auto MakeInvoker()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1428
static constexpr auto I7
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:101
static constexpr auto GemmK1Number
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:111
InDataType ABDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:92
static constexpr auto NXdlPerWave32
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:85
static constexpr auto I3
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:97
std::string GetTypeString() const override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1468
static constexpr auto I2
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:96
InDataType CDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:89
static constexpr auto I0
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:94
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1431
static constexpr bool IsValidCompilationParameter()
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1325
OutDataType ADataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:87
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1394
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:977
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1020
DeviceConvNdBwdDataNwcKxcNwk_Xdl DeviceOp
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:81
static constexpr auto K1Number
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:110
static bool IsSupportedArgument(const Argument &arg)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1331
WeiDataType BDataType
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:88
static constexpr auto I1
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:95
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, ABDataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN, Sequence< 2, 3, 0, 1, 7, 5, 4, 6 >, 7, CThreadTransferDstScalarPerVector > GridwiseGemmBase
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:983
static constexpr auto I4
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:98
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition device_convnd_bwd_data_nwc_kxc_nwk_xdl.hpp:1399
#define CK_ENV(name)
Definition utility/env.hpp:129