BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ > Struct Template Reference
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ > Struct Template Reference
#include <block_fmha_pipeline_qr_ks_vs_async.hpp>
Public Types | |
| using | Problem = remove_cvref_t<Problem_> |
| using | Policy = remove_cvref_t<Policy_> |
| using | QDataType = remove_cvref_t<typename Problem::QDataType> |
| using | KDataType = remove_cvref_t<typename Problem::KDataType> |
| using | VDataType = remove_cvref_t<typename Problem::VDataType> |
| using | SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
| using | SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
| using | BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
| using | RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
| using | LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
| using | PDataType = remove_cvref_t<typename Problem::PDataType> |
| using | OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
| using | ODataType = remove_cvref_t<typename Problem::ODataType> |
| using | AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
| using | FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
| using | BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
| using | VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
| using | DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout> |
Public Member Functions | |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const |
| template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices> | |
| CK_TILE_HOST_DEVICE auto | operator() (const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, DropoutType &dropout) const |
Static Public Member Functions | |
| static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t | GetSmemSize () |
Static Public Attributes | |
| static constexpr bool | kQLoadOnce = true |
| static constexpr index_t | kBlockSize = Problem::kBlockSize |
| static constexpr index_t | kM0 = BlockFmhaShape::kM0 |
| static constexpr index_t | kN0 = BlockFmhaShape::kN0 |
| static constexpr index_t | kK0 = BlockFmhaShape::kK0 |
| static constexpr index_t | kN1 = BlockFmhaShape::kN1 |
| static constexpr index_t | kK1 = BlockFmhaShape::kK1 |
| static constexpr index_t | kQKHeaddim = BlockFmhaShape::kQKHeaddim |
| static constexpr index_t | kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim |
| static constexpr bool | kIsGroupMode = Problem::kIsGroupMode |
| static constexpr bool | kPadSeqLenQ = true |
| static constexpr bool | kPadSeqLenK = Problem::kPadSeqLenK |
| static constexpr bool | kPadHeadDimQ = true |
| static constexpr bool | kPadHeadDimV = true |
| static constexpr bool | kHasLogitsSoftCap = Problem::kHasLogitsSoftCap |
| static constexpr auto | BiasEnum = Problem::BiasEnum |
| static constexpr bool | kStoreLSE = Problem::kStoreLSE |
| static constexpr bool | kHasDropout = Problem::kHasDropout |
| static constexpr index_t | kAlignmentQ = Policy::template GetAlignmentQ<Problem>() |
| static constexpr index_t | kAlignmentK = Policy::template GetAlignmentK<Problem>() |
| static constexpr index_t | kAlignmentV |
| static constexpr index_t | kAlignmentO = Policy::template GetAlignmentO<Problem>() |
| static constexpr index_t | kAlignmentBias |
| static constexpr index_t | kBlockPerCu |
| static constexpr const char * | name = "qr_async" |
Member Typedef Documentation
◆ AttentionVariant
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::AttentionVariant = remove_cvref_t<typename Problem::AttentionVariant> |
◆ BiasDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::BiasDataType = remove_cvref_t<typename Problem::BiasDataType> |
◆ BlockFmhaShape
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape> |
◆ DropoutType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout> |
◆ FmhaMask
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::FmhaMask = remove_cvref_t<typename Problem::FmhaMask> |
◆ KDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::KDataType = remove_cvref_t<typename Problem::KDataType> |
◆ LSEDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::LSEDataType = remove_cvref_t<typename Problem::LSEDataType> |
◆ OaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::OaccDataType = remove_cvref_t<typename Problem::OaccDataType> |
◆ ODataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::ODataType = remove_cvref_t<typename Problem::ODataType> |
◆ PDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::PDataType = remove_cvref_t<typename Problem::PDataType> |
◆ Policy
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::Policy = remove_cvref_t<Policy_> |
◆ Problem
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::Problem = remove_cvref_t<Problem_> |
◆ QDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::QDataType = remove_cvref_t<typename Problem::QDataType> |
◆ RandValOutputDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType> |
◆ SaccDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::SaccDataType = remove_cvref_t<typename Problem::SaccDataType> |
◆ SMPLComputeDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType> |
◆ VDataType
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::VDataType = remove_cvref_t<typename Problem::VDataType> |
◆ VLayout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
| using ck_tile::BlockFmhaPipelineQRKSVSAsync< Problem_, Policy_ >::VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout> |
Member Function Documentation
◆ GetSmemSize()
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
inline static constexpr |
◆ operator()() [1/2]
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
|
inline |
◆ operator()() [2/2]
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
template<typename QDramBlockWindowTmp, typename KDramBlockWindowTmp, typename VDramBlockWindowTmp, typename BiasDramBlockWindowTmp, typename RandValDramBlockWindowTmp, typename LSEDramBlockWindowTmp, typename QElementFunction, typename KElementFunction, typename VElementFunction, typename BiasElementFunction, typename LSEElementFunction, typename SAccElementFunction, typename PComputeElementFunction, typename OAccElementFunction, typename PositionEncoding, typename AttentionVariantParams, typename BlockIndices>
|
inline |
NOTICE: the bias may be a materialized mask that includes -inf values, which needs special consideration. ALiBi does not have this problem.
Member Data Documentation
◆ BiasEnum
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kAlignmentBias
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
Initial value:
= kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>()
static constexpr bool kPadSeqLenK
Definition block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
◆ kAlignmentK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kAlignmentO
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kAlignmentQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kAlignmentV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
Initial value:
= []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}()
◆ kBlockPerCu
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kBlockSize
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kHasDropout
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kHasLogitsSoftCap
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kIsGroupMode
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kK0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kK1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kM0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kN0
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kN1
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kPadHeadDimV
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kPadSeqLenK
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kPadSeqLenQ
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kQLoadOnce
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kStoreLSE
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ kSubQKHeaddim
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
◆ name
template<typename Problem_, typename Policy_ = BlockFmhaPipelineQRKSVSAsyncDefaultPolicy>
|
static constexpr |
The documentation for this struct was generated from the following file: