MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ > Struct Template Reference

MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <mx_flatmm_kernel.hpp>

Inheritance diagram for ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >:
ck_tile::FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >

Public Types

using Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>
using TilePartitioner = remove_cvref_t<TilePartitioner_>
using FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>
using BlockGemmShape
using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>
using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>
using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>
using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>
using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>
using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>
using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>
using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>
using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>
using SplitKBatchOffset = typename Underlying::SplitKBatchOffset
Public Types inherited from ck_tile::FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >
using TilePartitioner
using FlatmmPipeline
using BlockGemmShape
using EpiloguePipeline
using ALayout
using BLayout
using ELayout
using DsLayout
using DsDataType
using ADataType
using BDataType
using EDataType

Public Member Functions

template<class ScaleM, class ScaleN>
CK_TILE_DEVICE void operator() (FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Public Member Functions inherited from ck_tile::FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >
CK_TILE_DEVICE void operator() (FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const

Static Public Member Functions

static CK_TILE_HOST const std::string GetName ()
template<class ScaleM, class ScaleN>
static CK_TILE_HOST constexpr auto GridSize (const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
template<memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
static CK_TILE_DEVICE auto MakeGemmTensorViews (const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
template<typename TensorView>
static CK_TILE_DEVICE auto MakeGemmPadViews (const TensorView &views)
template<typename PadView>
static CK_TILE_DEVICE auto MakeGemmTileWindows (const PadView &views, const index_t i_m, const index_t i_n)
template<class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
static CK_TILE_DEVICE void RunFlatmm (const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Static Public Member Functions inherited from ck_tile::FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >
static CK_TILE_HOST const std::string GetName ()
static CK_TILE_HOST constexpr auto GridSize (index_t M, index_t N, index_t KBatch)
static CK_TILE_HOST constexpr auto GridSize (const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
static CK_TILE_HOST constexpr auto BlockSize ()
static CK_TILE_HOST constexpr FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> MakeKernelArgs (const ScaleFlatmmHostArgs< ScaleM, ScaleN, DsDataType::size()> &hostArgs)
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPingSize ()
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemPongSize ()
static CK_TILE_HOST bool IsSupportedArgument (const KernelArgs &kargs)
static CK_TILE_DEVICE auto MakeGemmTensorViews (const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
static CK_TILE_DEVICE auto MakeGemmPadViews (const TensorView &views)
static CK_TILE_DEVICE auto MakeGemmTileWindows (const PadView &views, const index_t i_m, const index_t i_n)
static CK_TILE_DEVICE void RunFlatmm (const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)

Static Public Attributes

static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize
static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel
static constexpr int MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{})
static constexpr int NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{})
static constexpr int KThreadPerXdl = 64 / MThreadPerXdl
static constexpr int APackedSize = numeric_traits<ADataType>::PackedSize
static constexpr int BPackedSize = numeric_traits<BDataType>::PackedSize
static constexpr int MXdlPack = FlatmmPipeline::MXdlPack
static constexpr int NXdlPack = FlatmmPipeline::NXdlPack
static constexpr int KXdlPack = FlatmmPipeline::KXdlPack
static constexpr index_t NumDTensor = DsDataType::size()
static constexpr auto I0 = number<0>()
static constexpr auto I1 = number<1>()
static constexpr auto I2 = number<2>()
static constexpr auto I3 = number<3>()
static constexpr auto I4 = number<4>()
static constexpr auto I5 = number<5>()
Static Public Attributes inherited from ck_tile::FlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >
static constexpr index_t kBlockSize
static constexpr bool UsePersistentKernel
static constexpr index_t NumDTensor
static constexpr auto I0
static constexpr auto I1
static constexpr auto I2
static constexpr auto I3

Member Typedef Documentation

◆ ADataType

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>

◆ ALayout

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>

◆ BDataType

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>

◆ BLayout

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>

◆ BlockGemmShape

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::BlockGemmShape
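Initial value: not captured in this extract. The only surviving fragment is the cross-reference for the remove_cvref_t alias used in the initializer:
remove_cvref_t = remove_cv_t< std::remove_reference_t< T > >
Definition: type_traits.hpp:21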

◆ DsDataType

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>

◆ DsLayout

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>

◆ EDataType

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>

◆ ELayout

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>

◆ EpiloguePipeline

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>

◆ FlatmmPipeline

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::FlatmmPipeline = remove_cvref_t<MXFlatmmPipeline_>

◆ SplitKBatchOffset

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::SplitKBatchOffset = typename Underlying::SplitKBatchOffset

◆ TilePartitioner

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::TilePartitioner = remove_cvref_t<TilePartitioner_>

◆ Underlying

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
using ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::Underlying = FlatmmKernel<TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_>

Member Function Documentation

◆ GetName()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST const std::string ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::GetName ( )
inline static nodiscard

◆ GridSize()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
template<class ScaleM, class ScaleN>
CK_TILE_HOST constexpr auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::GridSize ( const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> & kargs)
inline static constexpr

◆ MakeGemmPadViews()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
template<typename TensorView>
CK_TILE_DEVICE auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::MakeGemmPadViews ( const TensorView & views)
inline static

◆ MakeGemmTensorViews()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
template<memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
CK_TILE_DEVICE auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::MakeGemmTensorViews ( const ADataType * a_ptr,
const BDataType * b_flat_ptr,
const std::array< const void *, NumDTensor > & ds_ptr,
EDataType * e_ptr,
const KernelArgs & kargs,
const SplitKBatchOffset & splitk_batch_offset )
inline static

◆ MakeGemmTileWindows()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
template<typename PadView>
CK_TILE_DEVICE auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::MakeGemmTileWindows ( const PadView & views,
const index_t i_m,
const index_t i_n )
inline static

◆ operator()()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
template<class ScaleM, class ScaleN>
CK_TILE_DEVICE void ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::operator() ( FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs,
int partition_idx = blockIdx.x ) const
inline

◆ RunFlatmm()

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
template<class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
CK_TILE_DEVICE void ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::RunFlatmm ( const ADataType * a_ptr,
const BDataType * b_flat_ptr,
const std::array< const void *, NumDTensor > & ds_ptr,
EDataType * e_ptr,
void * smem_ptr_ping,
void * smem_ptr_pong,
const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> & kargs,
const SplitKBatchOffset & splitk_batch_offset,
const index_t block_idx_m,
const index_t block_idx_n )
inline static

Member Data Documentation

◆ APackedSize

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::APackedSize = numeric_traits<ADataType>::PackedSize
static constexpr

◆ BPackedSize

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::BPackedSize = numeric_traits<BDataType>::PackedSize
static constexpr

◆ I0

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::I0 = number<0>()
static constexpr

◆ I1

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::I1 = number<1>()
static constexpr

◆ I2

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::I2 = number<2>()
static constexpr

◆ I3

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::I3 = number<3>()
static constexpr

◆ I4

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::I4 = number<4>()
static constexpr

◆ I5

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
auto ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::I5 = number<5>()
static constexpr

◆ KernelBlockSize

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
index_t ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::KernelBlockSize = FlatmmPipeline::BlockSize
static constexpr

◆ KThreadPerXdl

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::KThreadPerXdl = 64 / MThreadPerXdl
static constexpr

◆ KXdlPack

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::KXdlPack = FlatmmPipeline::KXdlPack
static constexpr

◆ MThreadPerXdl

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::MThreadPerXdl = BlockGemmShape::WarpTile::at(number<0>{})
static constexpr

◆ MXdlPack

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::MXdlPack = FlatmmPipeline::MXdlPack
static constexpr

◆ NThreadPerXdl

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::NThreadPerXdl = BlockGemmShape::WarpTile::at(number<1>{})
static constexpr

◆ NumDTensor

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
index_t ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::NumDTensor = DsDataType::size()
static constexpr

◆ NXdlPack

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
int ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::NXdlPack = FlatmmPipeline::NXdlPack
static constexpr

◆ UsePersistentKernel

template<typename TilePartitioner_, typename MXFlatmmPipeline_, typename EpiloguePipeline_>
bool ck_tile::MXFlatmmKernel< TilePartitioner_, MXFlatmmPipeline_, EpiloguePipeline_ >::UsePersistentKernel = FlatmmPipeline::UsePersistentKernel
static constexpr

The documentation for this struct was generated from the following file: mx_flatmm_kernel.hpp