device_gemm_multiple_d_xdl_cshuffle.hpp Source File
device_gemm_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_multiple_d_xdl_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:42
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
__host__ __device__ constexpr auto transform_tuples(F f, const X &x)
Definition tuple_helper.hpp:98
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition utility/array.hpp:14
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static __device__ constexpr auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:207
__host__ static __device__ constexpr auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:190
__host__ static __device__ constexpr bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &, index_t k_batch=1)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:334
__host__ static __device__ constexpr auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:224
__host__ static __device__ constexpr auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:257
__host__ static __device__ constexpr auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition gridwise_gemm_multiple_d_xdl_cshuffle.hpp:245
Definition utility/integral_constant.hpp:20
Definition functional2.hpp:33
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
BaseArgument()=default
BaseInvoker()=default
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:325
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:398
index_t MRaw_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:409
const BDataType * p_b_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:386
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:397
EGridDesc_M_N e_grid_desc_m_n_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:394
BElementwiseOperation b_element_op_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:405
index_t KRaw_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:411
index_t NRaw_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:410
EDataType * p_e_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:388
AGridDesc_M_K a_grid_desc_m_k_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:391
Block2ETileMap block_2_etile_map_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:401
AElementwiseOperation a_element_op_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:404
void Print() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:374
DsGridDesc_M_N ds_grid_desc_m_n_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:393
const ADataType * p_a_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:385
CDEElementwiseOperation cde_element_op_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:406
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:326
BGridDesc_N_K b_grid_desc_n_k_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:392
GridwiseGemm64::DsGridPointer p_ds_grid_
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:387
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:779
index_t NRaw
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:793
index_t MRaw
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:792
AElementwiseOperation a_element_op
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:787
constexpr index_t GetGridSize() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:870
DsGridDesc_M_N ds_grid_desc_m_n
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:774
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:765
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:778
constexpr bool IsValid() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:837
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))> EGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:754
EGridDesc_M_N e_grid_desc_m_n
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:775
Block2ETileMap block_2_etile_map
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:784
constexpr Descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:798
AGridDesc_M_K a_grid_desc_m_k
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:772
remove_cvref_t< decltype(ds_tuple())> DsGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:753
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))> AGridDesc_M_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:749
constexpr index_t GetBlockSize() const
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:868
index_t KRaw
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:794
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))> BGridDesc_N_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:751
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))> AGridDesc_AK0_M_AK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:756
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( ds_tuple()))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:762
bool has_main_k_block_loop
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:796
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap( DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> Block2ETileMap
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:768
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:780
BElementwiseOperation b_element_op
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:788
CDEElementwiseOperation cde_element_op
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:789
static constexpr auto ds_tuple()
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:743
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))> BGridDesc_BK0_N_BK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:759
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:781
BGridDesc_N_K b_grid_desc_n_k
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:773
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:416
DeviceOp::Argument Argument
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:417
float RunImp(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:420
INVOKER_RUN_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:494
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:165
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBGridDescriptor_BK0_N_BK1( BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:308
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemmBase
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:257
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:693
static constexpr auto I1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:174
std::string GetTypeString() const override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:699
remove_cvref_t< decltype(GridwiseGemm64::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:314
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:252
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:168
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:503
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:251
DeviceGemmMultipleD_Xdl_CShuffle DeviceOp
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:166
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:301
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:626
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:302
static constexpr auto make_descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation cde_element_op=CDEElementwiseOperation{})
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:878
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultAGridDescriptor_AK0_M_AK1( AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:305
static constexpr auto I0
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:173
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:250
static constexpr auto matrix_padder
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:178
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:253
static __device__ void Run(const Desc &desc, const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:891
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:171
static constexpr auto NXdlPerWave32
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:169
static auto MakeInvoker()
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:657
remove_cvref_t< decltype(GridwiseGemm64::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:311
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:661
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:621
static auto MakeDsGridDescriptor_M_N(const Array< index_t, NumDTensor > &MRaws, const Array< index_t, NumDTensor > &NRaws, const Array< index_t, NumDTensor > &DsStride)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:236
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:584
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:199
static constexpr auto I3
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:176
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:218
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:181
remove_cvref_t< decltype(GridwiseGemm64::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:319
static constexpr auto I2
Definition device_gemm_multiple_d_xdl_cshuffle.hpp:175
Definition device_gemm_multiple_d.hpp:36
Definition matrix_padder.hpp:180