StandardAttention Struct Reference

StandardAttention Struct Reference

Composable Kernel: ck_tile::StandardAttention Struct Reference
ck_tile::StandardAttention Struct Reference

#include <variants.hpp>

Public Member Functions

__device__ __host__ StandardAttention ()=default
template<typename Params, typename T>
__device__ __forceinline__ T QueryTransform (const Params &params, T q) const
template<typename Params, typename T>
__device__ __forceinline__ T LogitsTransform (const Params &params, T logits, uint32_t batch_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const
template<typename Params>
__device__ __forceinline__ bool LogitsMask (const Params &params, uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) const

Constructor & Destructor Documentation

◆ StandardAttention()

__device__ __host__ ck_tile::StandardAttention::StandardAttention ( )
default

Member Function Documentation

◆ LogitsMask()

template<typename Params>
__device__ __forceinline__ bool ck_tile::StandardAttention::LogitsMask ( const Params & params,
uint32_t batch_idx,
uint32_t qo_idx,
uint32_t kv_idx,
uint32_t qo_head_idx,
uint32_t kv_head_idx ) const
inline

◆ LogitsTransform()

template<typename Params, typename T>
__device__ __forceinline__ T ck_tile::StandardAttention::LogitsTransform ( const Params & params,
T logits,
uint32_t batch_idx,
uint32_t qo_head_idx,
uint32_t kv_head_idx ) const
inline

NOTICE: For better performance, we simply transform the thread buffer without calculating qo_idx/kv_idx.

◆ QueryTransform()

template<typename Params, typename T>
__device__ __forceinline__ T ck_tile::StandardAttention::QueryTransform ( const Params & params,
T q ) const
inline

The documentation for this struct was generated from the following file: