grouped_flatmm_kernel.hpp Source File#
grouped_flatmm_kernel.hpp
Go to the documentation of this file.
202 using UnderlyingGemmKernel = FlatmmKernel<TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_>;
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
Definition grouped_flatmm_kernel.hpp:79
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_flatmm_kernel.hpp:124
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_flatmm_kernel.hpp:123
const void * b_shuffle_ptr
Definition grouped_flatmm_kernel.hpp:121
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs(index_t *M_indices_, index_t M_, index_t N_, index_t K_, const void *a_ptr_, index_t stride_A_, const void *b_shuffle_ptr_, index_t stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void *c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition grouped_flatmm_kernel.hpp:81
const void * a_ptr
Definition grouped_flatmm_kernel.hpp:119
index_t k_batch
Definition grouped_flatmm_kernel.hpp:131
ScaleM scale_m
Definition grouped_flatmm_kernel.hpp:132
CK_TILE_HOST ContiguousGroupedFlatmmHostArgs()=default
index_t group_count
Definition grouped_flatmm_kernel.hpp:114
index_t stride_B
Definition grouped_flatmm_kernel.hpp:122
index_t stride_A
Definition grouped_flatmm_kernel.hpp:120
ScaleN scale_n
Definition grouped_flatmm_kernel.hpp:133
index_t stride_C
Definition grouped_flatmm_kernel.hpp:130
index_t * M_indices
Definition grouped_flatmm_kernel.hpp:115
Definition flatmm_kernel.hpp:229
Definition flatmm_kernel.hpp:249
static CK_TILE_HOST constexpr auto BlockSize()
Definition flatmm_kernel.hpp:330
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition flatmm_kernel.hpp:252
Definition flatmm_kernel.hpp:33
Definition grouped_flatmm_kernel.hpp:19
const void ** b_shuffle_ptr
Definition grouped_flatmm_kernel.hpp:60
CK_TILE_HOST GroupedFlatmmHostArgs(index_t group_count_, index_t *M_, index_t *N_, index_t *K_, const void **a_ptr_, index_t *stride_A_, const void **b_shuffle_ptr_, index_t *stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void **c_ptr_, index_t *stride_C_, index_t k_batch_, ScaleM *scale_m_=nullptr, ScaleN *scale_n_=nullptr)
Definition grouped_flatmm_kernel.hpp:21
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_flatmm_kernel.hpp:63
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_flatmm_kernel.hpp:62
CK_TILE_HOST GroupedFlatmmHostArgs()=default
index_t group_count
Definition grouped_flatmm_kernel.hpp:54
Definition grouped_flatmm_kernel.hpp:201
static constexpr index_t NumDTensor
Definition grouped_flatmm_kernel.hpp:217
static CK_TILE_HOST_DEVICE auto GridSize(const ContiguousGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition grouped_flatmm_kernel.hpp:269
static constexpr index_t kBlockSize
Definition grouped_flatmm_kernel.hpp:218
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition grouped_flatmm_kernel.hpp:205
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition grouped_flatmm_kernel.hpp:206
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition grouped_flatmm_kernel.hpp:210
FlatmmKernel< TilePartitioner_, FlatmmPipeline_, EpiloguePipeline_ > UnderlyingGemmKernel
Definition grouped_flatmm_kernel.hpp:202
static CK_TILE_HOST const std::string GetName()
Definition grouped_flatmm_kernel.hpp:228
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition grouped_flatmm_kernel.hpp:213
static CK_TILE_HOST constexpr auto MakeKernelArgs(const HostArgs &hostArgs)
Definition grouped_flatmm_kernel.hpp:336
CK_TILE_DEVICE void operator()(ContiguousGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition grouped_flatmm_kernel.hpp:398
CK_TILE_DEVICE void operator()(GroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition grouped_flatmm_kernel.hpp:354
static CK_TILE_HOST_DEVICE auto GridSize(const MaskedGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition grouped_flatmm_kernel.hpp:304
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition grouped_flatmm_kernel.hpp:214
static CK_TILE_HOST_DEVICE auto GridSize(const GroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > &kernelArgs)
Definition grouped_flatmm_kernel.hpp:238
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition grouped_flatmm_kernel.hpp:215
CK_TILE_DEVICE void operator()(MaskedGroupedFlatmmHostArgs< ScaleM, ScaleN, NumDTensor > kargs) const
Definition grouped_flatmm_kernel.hpp:436
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition grouped_flatmm_kernel.hpp:208
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition grouped_flatmm_kernel.hpp:211
typename UnderlyingGemmKernel::BlockGemmShape BlockGemmShape
Definition grouped_flatmm_kernel.hpp:203
Definition grouped_flatmm_kernel.hpp:140
index_t group_count
Definition grouped_flatmm_kernel.hpp:178
CK_TILE_HOST MaskedGroupedFlatmmHostArgs(index_t *M_indices_, index_t group_count_, index_t Max_M_, index_t N_, index_t K_, const void *a_ptr_, index_t stride_A_, const void *b_shuffle_ptr_, index_t stride_B_, const std::array< const void *, NumDTensor > &ds_ptr_, const std::array< index_t, NumDTensor > &stride_Ds_, void *c_ptr_, index_t stride_C_, index_t k_batch_, ScaleM scale_m_=nullptr, ScaleN scale_n_=nullptr)
Definition grouped_flatmm_kernel.hpp:142
CK_TILE_HOST MaskedGroupedFlatmmHostArgs()=default
index_t * M_indices
Definition grouped_flatmm_kernel.hpp:177
const std::array< const void *, NumDTensor > ds_ptr
Definition grouped_flatmm_kernel.hpp:186
index_t k_batch
Definition grouped_flatmm_kernel.hpp:194
const void * b_shuffle_ptr
Definition grouped_flatmm_kernel.hpp:184
index_t stride_C
Definition grouped_flatmm_kernel.hpp:193
const void * a_ptr
Definition grouped_flatmm_kernel.hpp:182
index_t stride_A
Definition grouped_flatmm_kernel.hpp:183
const std::array< index_t, NumDTensor > stride_Ds
Definition grouped_flatmm_kernel.hpp:187
index_t stride_B
Definition grouped_flatmm_kernel.hpp:185