FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ > Struct Template Reference

FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ > Struct Template Reference

Composable Kernel: ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ > Struct Template Reference
ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ > Struct Template Reference

#include <fmha_bwd_kernel.hpp>

Classes

struct  t2s
struct  t2s< float >
struct  t2s< ck_tile::fp16_t >
struct  t2s< ck_tile::bf16_t >
struct  FmhaBwdConvertQGradEmptyKargs
struct  FmhaBwdConvertQGradCommonKargs
struct  FmhaBwdConvertQGradDeterministicKargs
struct  FmhaBwdConvertQGradBatchModeKargs
struct  FmhaBwdConvertQGradGroupModeKargs

Public Types

using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>
using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>
using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>
using Kargs

Public Member Functions

CK_TILE_DEVICE void operator() (Kargs kargs) const

Static Public Member Functions

static CK_TILE_HOST std::string GetName ()
template<bool Cond = !kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *dq_acc_ptr, void *dq_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t batch_stride_dq, ck_tile::index_t batch_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
template<bool Cond = kIsGroupMode>
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs (const void *dq_acc_ptr, void *dq_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, const void *cu_seqlen_q_ptr, const void *cu_seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t stride_dq, ck_tile::index_t stride_dq_acc, ck_tile::index_t nhead_stride_dq, ck_tile::index_t nhead_stride_dq_acc, ck_tile::index_t split_stride_dq_acc)
static CK_TILE_HOST constexpr auto GridSize (ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
static CK_TILE_DEVICE constexpr auto GetTileIndex ()
static CK_TILE_HOST dim3 BlockSize ()
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize ()

Static Public Attributes

static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu
static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0
static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0
static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim
static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic

Member Typedef Documentation

◆ AccDataType

template<typename FmhaBwdConvertQGrad_>
using ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>

◆ FmhaBwdConvertQGrad

template<typename FmhaBwdConvertQGrad_>
using ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>

◆ Kargs

template<typename FmhaBwdConvertQGrad_>
using ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::Kargs
Initial value:

◆ QGradDataType

template<typename FmhaBwdConvertQGrad_>
using ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>

Member Function Documentation

◆ BlockSize()

template<typename FmhaBwdConvertQGrad_>
CK_TILE_HOST dim3 ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::BlockSize ( )
inline static

◆ GetName()

template<typename FmhaBwdConvertQGrad_>
CK_TILE_HOST std::string ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename FmhaBwdConvertQGrad_>
CK_TILE_HOST_DEVICE constexpr ck_tile::index_t ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::GetSmemSize ( )
inline static constexpr

◆ GetTileIndex()

template<typename FmhaBwdConvertQGrad_>
CK_TILE_DEVICE constexpr auto ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::GetTileIndex ( )
inline static constexpr

◆ GridSize()

template<typename FmhaBwdConvertQGrad_>
CK_TILE_HOST constexpr auto ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::GridSize ( ck_tile::index_t batch_size_,
ck_tile::index_t nhead_,
ck_tile::index_t seqlen_q_ )
inline static constexpr

◆ MakeKargs() [1/2]

template<typename FmhaBwdConvertQGrad_>
template<bool Cond = !kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::MakeKargs ( const void * dq_acc_ptr,
void * dq_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc )
inline static constexpr

◆ MakeKargs() [2/2]

template<typename FmhaBwdConvertQGrad_>
template<bool Cond = kIsGroupMode>
CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::MakeKargs ( const void * dq_acc_ptr,
void * dq_ptr,
const void * seqstart_q_ptr,
const void * seqstart_k_ptr,
const void * seqlen_q_ptr,
const void * seqlen_k_ptr,
const void * cu_seqlen_q_ptr,
const void * cu_seqlen_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc )
inline static constexpr

◆ operator()()

template<typename FmhaBwdConvertQGrad_>
CK_TILE_DEVICE void ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::operator() ( Kargs kargs) const
inline

Member Data Documentation

◆ kBlockPerCu

template<typename FmhaBwdConvertQGrad_>
ck_tile::index_t ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu
static constexpr

◆ kBlockSize

template<typename FmhaBwdConvertQGrad_>
ck_tile::index_t ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kBlockSize = FmhaBwdConvertQGrad::kBlockSize
static constexpr

◆ kIsDeterministic

template<typename FmhaBwdConvertQGrad_>
bool ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic
static constexpr

◆ kIsGroupMode

template<typename FmhaBwdConvertQGrad_>
bool ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode
static constexpr

◆ kM0

template<typename FmhaBwdConvertQGrad_>
ck_tile::index_t ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kM0 = FmhaBwdConvertQGrad::kM0
static constexpr

◆ kN0

template<typename FmhaBwdConvertQGrad_>
ck_tile::index_t ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kN0 = FmhaBwdConvertQGrad::kN0
static constexpr

◆ kPadHeadDimQ

template<typename FmhaBwdConvertQGrad_>
bool ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ
static constexpr

◆ kPadSeqLenQ

template<typename FmhaBwdConvertQGrad_>
bool ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ
static constexpr

◆ kQKHeaddim

template<typename FmhaBwdConvertQGrad_>
ck_tile::index_t ck_tile::FmhaBwdConvertQGradKernel< FmhaBwdConvertQGrad_ >::kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim
static constexpr

The documentation for this struct was generated from the following file: fmha_bwd_kernel.hpp