FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <fmha_fwd_kernel.hpp>

Classes

struct  t2s
struct  t2s< float >
struct  t2s< ck_tile::fp16_t >
struct  t2s< ck_tile::bf16_t >
struct  t2s< ck_tile::fp8_t >
struct  t2s< ck_tile::bf8_t >
struct  t2s< ck_tile::fp8_t, ck_tile::bf16_t >
struct  t2s< ck_tile::fp8_t, ck_tile::fp32_t >
struct  FmhaFwdEmptyKargs
struct  FmhaFwdCommonKargs
struct  FmhaFwdLogitsSoftCapKargs
struct  FmhaFwdCommonBiasKargs
struct  FmhaFwdBatchModeBiasKargs
struct  FmhaFwdAlibiKargs
struct  FmhaFwdMaskKargs
struct  FmhaFwdFp8StaticQuantKargs
struct  FmhaFwdCommonLSEKargs
struct  FmhaFwdDropoutSeedOffset
struct  FmhaFwdCommonDropoutKargs
struct  FmhaFwdBatchModeDropoutKargs
struct  FmhaFwdSkipMinSeqlenQKargs
struct  FmhaFwdBatchModeKargs
struct  FmhaFwdGroupModeKargs
struct  BlockIndices

Public Types

using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>
using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>
using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>
using RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>
using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>
using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>
using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>
using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>
using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const
CK_TILE_DEVICE void run_ (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargsImpl (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
static CK_TILE_DEVICE constexpr auto GetTileIndex (const Kargs &kargs)
static CK_TILE_HOST dim3 BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum
static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr bool kHasDropout = FmhaPipeline::kHasDropout
static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ
static constexpr bool kHasMask = FmhaMask::IsMasking
static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy
static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad
static constexpr bool kIsAvailable = !kUseTrLoad
static constexpr std::string_view kPipelineName = FmhaPipeline::name

Member Typedef Documentation

◆ AttentionVariant

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>

◆ BiasDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>

◆ EpiloguePipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>

◆ FmhaMask

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>

◆ FmhaPipeline

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>

◆ Kargs

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>

◆ KDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>

◆ LSEDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>

◆ ODataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>

◆ QDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>

◆ RandValOutputDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::RandValOutputDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>
Note: ck_tile::remove_cvref_t< T > is an alias for remove_cv_t< std::remove_reference_t< T > > (definition: type_traits.hpp:21).

◆ SaccDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>

◆ VDataType

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>

◆ VLayout

template<typename FmhaPipeline_, typename EpiloguePipeline_>
using ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>

Member Function Documentation

◆ BlockSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST dim3 ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BlockSize ( )
inline static

◆ GetName()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST std::string ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GetTileIndex ( const Kargs & kargs)
inline static constexpr

◆ GridSize()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_HOST constexpr auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::GridSize ( ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_,
ck_tile::index_t hdim_v_,
bool has_padded_seqlen_k = false )
inline static constexpr

◆ MakeKargs() [1/4]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * rand_val_ptr,
void * lse_ptr,
void * o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple< const void *, const void * > & drop_seed_offset,
const void * cu_seqlen_q_ptr = nullptr,
const void * cu_seqlen_k_ptr = nullptr )
inline static constexpr

◆ MakeKargs() [2/4]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * rand_val_ptr,
void * lse_ptr,
void * o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
const std::tuple< uint64_t, uint64_t > & drop_seed_offset,
const void * cu_seqlen_q_ptr = nullptr,
const void * cu_seqlen_k_ptr = nullptr )
inline static constexpr

◆ MakeKargs() [3/4]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * rand_val_ptr,
void * lse_ptr,
void * o_ptr,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_q_ptr,
const void * seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple< const void *, const void * > & drop_seed_offset,
const void * cu_seqlen_q_ptr = nullptr,
const void * cu_seqlen_k_ptr = nullptr )
inline static constexpr

◆ MakeKargs() [4/4]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargs ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * rand_val_ptr,
void * lse_ptr,
void * o_ptr,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_q_ptr,
const void * seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
const std::tuple< uint64_t, uint64_t > & drop_seed_offset,
const void * cu_seqlen_q_ptr = nullptr,
const void * cu_seqlen_k_ptr = nullptr )
inline static constexpr

◆ MakeKargsImpl() [1/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargsImpl ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * rand_val_ptr,
void * lse_ptr,
void * o_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_bias,
ck_tile::index_t batch_stride_randval,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t batch_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
float p_drop,
bool s_randval,
std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset,
const void * cu_seqlen_q_ptr = nullptr,
const void * cu_seqlen_k_ptr = nullptr )
inline static constexpr

◆ MakeKargsImpl() [2/2]

template<typename FmhaPipeline_, typename EpiloguePipeline_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::MakeKargsImpl ( const void * q_ptr,
const void * k_ptr,
const void * v_ptr,
const void * bias_ptr,
void * rand_val_ptr,
void * lse_ptr,
void * o_ptr,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_q_ptr,
const void * seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
float scale_s,
float scale_p,
float scale_o,
float logits_soft_cap,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_v,
ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o,
ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right,
ck_tile::index_t mask_type,
ck_tile::index_t min_seqlen_q,
float p_drop,
bool s_randval,
std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset,
const void * cu_seqlen_q_ptr = nullptr,
const void * cu_seqlen_k_ptr = nullptr )
inline static constexpr

◆ operator()()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE void ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::operator() ( Kargs kargs) const
inline

◆ run_()

template<typename FmhaPipeline_, typename EpiloguePipeline_>
CK_TILE_DEVICE void ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::run_ ( Kargs kargs) const
inline

FIXME: Before C++20, capturing structured binding variables is not supported. Remove the following copy capture of 'i_nhead' once C++20 is in use.

FIXME: Before C++20, capturing structured binding variables is not supported. Remove the following copy capture of 'i_nhead' once C++20 is in use.

Member Data Documentation

◆ BiasEnum

template<typename FmhaPipeline_, typename EpiloguePipeline_>
auto ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::BiasEnum = FmhaPipeline::BiasEnum
static constexpr

◆ kBlockPerCu

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCu = FmhaPipeline::kBlockPerCu
static constexpr

◆ kBlockPerCuInput

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaPipeline_, typename EpiloguePipeline_>
ck_tile::index_t ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kBlockSize = FmhaPipeline::kBlockSize
static constexpr

◆ kDoFp8StaticQuant

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant
static constexpr

◆ kHasDropout

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasDropout = FmhaPipeline::kHasDropout
static constexpr

◆ kHasLogitsSoftCap

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap
static constexpr

◆ kHasMask

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kHasMask = FmhaMask::IsMasking
static constexpr

◆ kIsAvailable

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsAvailable = !kUseTrLoad
static constexpr

◆ kIsGroupMode

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kIsGroupMode = FmhaPipeline::kIsGroupMode
static constexpr

◆ kPadHeadDimQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ
static constexpr

◆ kPadHeadDimV

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadHeadDimV = FmhaPipeline::kPadHeadDimV
static constexpr

◆ kPadSeqLenK

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenK = FmhaPipeline::kPadSeqLenK
static constexpr

◆ kPadSeqLenQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ
static constexpr

◆ kPipelineName

template<typename FmhaPipeline_, typename EpiloguePipeline_>
std::string_view ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kPipelineName = FmhaPipeline::name
static constexpr

◆ kSkipMinSeqlenQ

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ
static constexpr

◆ kStoreLSE

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kStoreLSE = FmhaPipeline::kStoreLSE
static constexpr

◆ kUseAsyncCopy

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy
static constexpr

◆ kUseTrLoad

template<typename FmhaPipeline_, typename EpiloguePipeline_>
bool ck_tile::FmhaFwdKernel< FmhaPipeline_, EpiloguePipeline_ >::kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad
static constexpr

The documentation for this struct was generated from the following file: