block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp Source File

block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp Source File#

Composable Kernel: block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp Source File
block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13// In this pipeline, Q, K and V are all located in LDS
15 : BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
16 /* AsyncCopy = */ false,
17 /* NumPrefetchK = */ 1,
18 /* NumPrefetchV = */ 1>
19{
20 using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
21 /* AsyncCopy = */ false,
22 /* NumPrefetchK = */ 1,
23 /* NumPrefetchV = */ 1>;
24
// Returns the vector width, in elements of Problem::QDataType, used for Q:
// the smaller of
//   - ElemPerThread: (kM0 * kSubQKHeaddim) / kBlockSize, i.e. the Q tile's
//     element count divided by the block size, and
//   - MaxVectorSize: how many QDataType elements fit in 16 bytes.
// The result is a compile-time constant (all inputs are constexpr members of Problem).
// NOTE(review): kBlockSize is presumably the number of threads per workgroup, which
// would make ElemPerThread a per-thread element count -- confirm against Problem.
25 template <typename Problem>
26 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
27 {
28 constexpr index_t kBlockSize = Problem::kBlockSize;
29 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
30 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
31
// Number of QDataType elements that fit in 16 bytes (integer division).
32 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
33
34 // Keep this in sync with MakeQDramTileDistribution(), which derives its kMaxVecLoad from the same min(ElemPerThread, MaxVectorSize) formula.
35 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
// Compilation fails if the integer division above yields zero, i.e. the tile
// has fewer elements than kBlockSize.
36 static_assert(0 < ElemPerThread);
37 return min(ElemPerThread, MaxVectorSize);
38 }
39
40 template <typename Problem>
41 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
42 {
44
45 return static_cast<index_t>(16 / sizeof(OaccDataType));
46 }
47
48 template <typename Problem>
50 {
51 constexpr index_t kBlockSize = Problem::kBlockSize;
52 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
53 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
54
55 constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
56
57 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
58 static_assert(0 < ElemPerThread);
59 constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
60
61 constexpr index_t KPerThread = kMaxVecLoad;
62 constexpr index_t KThreads = kKPerBlock / KPerThread;
63 constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
64 constexpr index_t NumWarps = kBlockSize / get_warp_size();
65 constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
66
75 }
76
77 template <typename Problem>
79 {
80 return BasePolicy::template MakeQRegTileDistribution<Problem>();
81 }
82
83 template <typename Problem>
84 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
85 {
86 // TODO: this is for 3d layout
88 return static_cast<index_t>(16 / sizeof(QDataType));
89 }
90
91 template <typename Problem>
93 {
94 constexpr index_t kBlockSize = Problem::kBlockSize;
95 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
96 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
97
98 constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
99 static_assert(0 < ElemPerThread);
100 constexpr index_t kKPack = min(ElemPerThread, GetSmemKPackQ<Problem>());
101
102 constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
104 make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
106 number<1>{});
107
108 constexpr auto q_lds_block_desc = transform_tensor_descriptor(
109 q_lds_block_desc_0,
115
116 return q_lds_block_desc;
117 }
118
119 template <typename Problem>
120 CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
121 {
123 return static_cast<index_t>(16 / sizeof(SDataType));
124 }
125
126 template <typename Problem>
128 {
129 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
130 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
131 constexpr index_t kNPack = GetSmemNPackS<Problem>();
132
133 constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
135 make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
137 number<1>{});
138
139 constexpr auto s_lds_block_desc = transform_tensor_descriptor(
140 s_lds_block_desc_0,
146
147 return s_lds_block_desc;
148 }
149
150 template <typename Problem>
152 {
154
155 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
156 using WG = remove_cvref_t<decltype(config.template at<0>())>;
157 constexpr index_t MWarp = config.template at<1>();
158 constexpr index_t NWarp = config.template at<2>();
159
160 static_assert(MWarp == 1, "Check failed!");
161
162 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
163 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
164 constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
165
166 // K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
167 constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
168 constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
169 constexpr index_t K1 = kKPerBlock / (K2 * K3);
170 constexpr index_t K0 = kTileK / kKPerBlock;
171 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
172 constexpr index_t M1 = MWarp;
173 constexpr index_t M0 = kMPerBlock / (M2 * M1);
174
175 constexpr auto s2_block_dstr_encoding =
182
183 constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
184
185 return s2_block_dstr;
186 }
187
188 template <typename Problem>
190 {
191 return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
192 sizeof(typename Problem::QDataType);
193 }
194
195 template <typename Problem>
197 {
198 return MakeKLdsBlockDescriptor<Problem>().get_element_space_size() *
199 sizeof(typename Problem::KDataType);
200 }
201
202 template <typename Problem>
204 {
205 return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
206 sizeof(typename Problem::VDataType);
207 }
208
209 template <typename Problem>
211 {
212 return MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
213 sizeof(typename Problem::SaccDataType);
214 }
215
216 template <typename Problem>
222};
223
224} // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:19
static CK_TILE_HOST_DEVICE constexpr auto MakeSLdsBlockDescriptor()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:127
BlockFmhaPipelineQXKSVSCustomPolicy< true, false, 1, 1 > BasePolicy
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:20
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOacc()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:41
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeQ()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:189
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:92
static CK_TILE_HOST_DEVICE constexpr auto MakeSRegTileDistribution()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:151
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegTileDistribution()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:78
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:49
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQ()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:84
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:217
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeK()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:196
static CK_TILE_HOST_DEVICE constexpr auto GetSmemNPackS()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:120
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeV()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:203
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:26
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSizeS()
Definition block_fmha_fwd_splitkv_pipeline_nwarp_sshuffle_qr_ks_vs_default_policy.hpp:210
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:266
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:620
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsBlockDescriptor()
Definition block_fmha_pipeline_qx_ks_vs_custom_policy.hpp:486
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192