129template <
typename AsLayout,
135 typename AccDataType,
136 typename CShuffleDataType,
139 typename AElementwiseOperation,
140 typename BElementwiseOperation,
141 typename CDEElementwiseOperation,
153 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
154 typename ABlockTransferThreadClusterArrangeOrder,
155 typename ABlockTransferSrcAccessOrder,
156 index_t ABlockTransferSrcVectorDim,
157 index_t ABlockTransferSrcScalarPerVector,
158 index_t ABlockTransferDstScalarPerVector_AK1,
159 bool ABlockLdsExtraM,
160 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
161 typename BBlockTransferThreadClusterArrangeOrder,
162 typename BBlockTransferSrcAccessOrder,
163 index_t BBlockTransferSrcVectorDim,
164 index_t BBlockTransferSrcScalarPerVector,
165 index_t BBlockTransferDstScalarPerVector_BK1,
166 bool BBlockLdsExtraN,
167 index_t CShuffleMRepeatPerShuffle,
168 index_t CShuffleNRepeatPerShuffle,
169 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
170 typename CDEShuffleBlockTransferScalarPerVectors,
173 typename ComputeTypeA = EDataType,
174 typename ComputeTypeB = ComputeTypeA,
175 bool PermuteA =
false,
176 bool PermuteB =
false>
186 AElementwiseOperation,
187 BElementwiseOperation,
188 CDEElementwiseOperation>
206 AElementwiseOperation,
207 BElementwiseOperation,
208 CDEElementwiseOperation,
220 ABlockTransferThreadClusterLengths_AK0_M_AK1,
221 ABlockTransferThreadClusterArrangeOrder,
222 ABlockTransferSrcAccessOrder,
223 ABlockTransferSrcVectorDim,
224 ABlockTransferSrcScalarPerVector,
225 ABlockTransferDstScalarPerVector_AK1,
228 BBlockTransferThreadClusterLengths_BK0_N_BK1,
229 BBlockTransferThreadClusterArrangeOrder,
230 BBlockTransferSrcAccessOrder,
231 BBlockTransferSrcVectorDim,
232 BBlockTransferSrcScalarPerVector,
233 BBlockTransferDstScalarPerVector_BK1,
236 CShuffleMRepeatPerShuffle,
237 CShuffleNRepeatPerShuffle,
238 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
239 CDEShuffleBlockTransferScalarPerVectors,
262 CDEShuffleBlockTransferScalarPerVectors,
282 static auto MakeArgument(std::array<const void*, GridwiseGemm::NumATensor> p_as,
283 std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
284 std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
289 std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
290 std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
291 std::array<index_t, GridwiseGemm::NumDTensor> StrideDs,
294 AElementwiseOperation a_element_op,
295 BElementwiseOperation b_element_op,
296 CDEElementwiseOperation cde_element_op)
301 static_cast<EDataType*
>(p_e),
318 std::unique_ptr<BaseArgument>
320 std::array<const void*, GridwiseGemm::NumBTensor> p_bs,
321 std::array<const void*, GridwiseGemm::NumDTensor> p_ds,
326 std::array<ck::index_t, GridwiseGemm::NumATensor> StrideAs,
327 std::array<ck::index_t, GridwiseGemm::NumBTensor> StrideBs,
328 std::array<ck::index_t, GridwiseGemm::NumDTensor> StrideDs,
331 AElementwiseOperation a_element_op,
332 BElementwiseOperation b_element_op,
333 CDEElementwiseOperation cde_element_op)
override
335 return std::make_unique<Argument>(p_as,
338 static_cast<EDataType*
>(p_e),
355 return std::make_unique<Invoker>(
Invoker{});
361 auto str = std::stringstream();
363 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
367 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
375 str <<
"DeviceGemmMultipleABD_Wmma_CShuffleV3"
381 str << std::string(ALayout_::name)[0];
386 str << std::string(BLayout_::name)[0];
391 str << std::string(DLayout::name)[0];
393 str << std::string(ELayout::name)[0]
398 << MPerBlock <<
"x" << NPerBlock <<
"x" << KPerBlock <<
", "
400 << MPerWmma <<
"x"<<NPerWmma <<
", "
402 << MRepeat <<
"x" << NRepeat <<
", "
404 << ABlockTransferSrcScalarPerVector <<
"x" << BBlockTransferSrcScalarPerVector <<
", "
405 <<
"BlkGemmPipelineScheduler: "
406 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] <<
", "
407 <<
"BlkGemmPipelineVersion: "
408 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] <<
", "
409 <<
"BlkGemmPipelinePrefetchStages: "
410 << GridwiseGemm::BlockwiseGemmPipe::PrefetchStages <<
", "
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
ck::GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1Value, BK1Value, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB, false >::KPack static constexpr index_t KPack
Definition gridwise_gemm_wmma_cshuffle_v3_common.hpp:154
"Universal" GEMM kernel with SplitK support.
Definition gridwise_gemm_wmma_cshuffle_v3.hpp:233
Definition functional2.hpp:33
Definition device_base.hpp:197
Helper structure responsible for kernel invocation.
Definition device_gemm_wmma_cshuffle_v3_common.hpp:57
Definition device_gemm_wmma_cshuffle_v3_common.hpp:43
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_wmma_cshuffle_v3_common.hpp:268
"Universal" GEMM operation with SplitK support and multiple D tensors.
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:189
remove_cvref_t< tuple_element_t< 0, BsLayout > > BLayout
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:193
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:271
static auto MakeArgument(std::array< const void *, GridwiseGemm::NumATensor > p_as, std::array< const void *, GridwiseGemm::NumBTensor > p_bs, std::array< const void *, GridwiseGemm::NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, std::array< ck::index_t, GridwiseGemm::NumATensor > StrideAs, std::array< ck::index_t, GridwiseGemm::NumBTensor > StrideBs, std::array< index_t, GridwiseGemm::NumDTensor > StrideDs, index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:282
static auto MakeInvoker()
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:315
DeviceGemm_Wmma_CShuffleV3_Common< GridwiseGemm, AsDataType, BsDataType, DsDataType, EDataType, MPerBlock, NPerBlock, KPerBlock, BlockSize, AK1, BK1, GemmSpec, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > DeviceGemmCommon
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:249
remove_cvref_t< tuple_element_t< 0, AsLayout > > ALayout
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:192
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:353
typename GridwiseGemm::Argument Argument
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:247
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, GridwiseGemm::NumATensor > p_as, std::array< const void *, GridwiseGemm::NumBTensor > p_bs, std::array< const void *, GridwiseGemm::NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, std::array< ck::index_t, GridwiseGemm::NumATensor > StrideAs, std::array< ck::index_t, GridwiseGemm::NumBTensor > StrideBs, std::array< ck::index_t, GridwiseGemm::NumDTensor > StrideDs, index_t StrideE, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:319
GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, DsLayout, ELayout, AsDataType, BsDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerWmma, NPerWmma, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemm
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:195
std::string GetTypeString() const override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:359
typename DeviceGemmCommon::Invoker Invoker
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:269
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_abd_wmma_cshuffle_v3.hpp:277
Definition device_gemm_multiple_abd.hpp:78