block_gemm_areg_breg_creg_v1.hpp Source File

block_gemm_areg_breg_creg_v1.hpp Source File

Composable Kernel: block_gemm_areg_breg_creg_v1.hpp Source File
block_gemm_areg_breg_creg_v1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// A is block distributed tensor
12// B is block distributed tensor
13// C is block distributed tensor
// Template parameters (as used by the visible code below):
//   Problem_    - supplies kBlockSize and NumWaveGroups.
//   Policy_     - supplies GetWarpGemmMWarpNWarp<Problem>(), whose result carries
//                 the warp GEMM type and the MWarp / NWarp counts.
//   TransposeC_ - exposed as TransposeC; swaps the Y-dimension order of the C
//                 distribution and the (m, n) order of the C slice index.
// NOTE(review): the struct declaration line (listing line 17) is missing from
// this extraction; the file name suggests BlockGemmARegBRegCRegV1 - confirm.
14template <typename Problem_,
15 typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy,
16 bool TransposeC_ = false>
18{
19 private:
// Compile-time traits for the block GEMM: block tile extents, the warp GEMM
// type, the warp counts and the per-warp iteration counts.
// NOTE(review): listing lines 23-28 are missing from this extraction. The names
// Problem, Policy and BlockGemmShape used below (and the A/B/C data types that
// the enclosing class reads from Traits) are presumably declared there - confirm
// against the original header.
20 template <typename PipelineProblem_, typename GemmPolicy_>
21 struct GemmTraits_
22 {
29
30 static constexpr index_t kBlockSize = Problem::kBlockSize;
31
// Block tile extents along M, N and K, taken from BlockGemmShape.
32 static constexpr index_t MPerBlock = BlockGemmShape::kM;
33 static constexpr index_t NPerBlock = BlockGemmShape::kN;
34 static constexpr index_t KPerBlock = BlockGemmShape::kK;
35
// The policy returns one tuple-like object: element 0 carries the warp GEMM
// type, elements 1 and 2 are the warp counts MWarp and NWarp.
36 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
37 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
38
39 static constexpr index_t MWarp = config.template at<1>();
40 static constexpr index_t NWarp = config.template at<2>();
// Trip counts of the static_for loops in operator().
// NOTE(review): index_t division truncates and no divisibility static_assert is
// visible in this listing - confirm the block tile is an exact multiple.
41 static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
42 static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
43 static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
44
45 static constexpr index_t KPack = WarpGemm::kKPerThread;
46 };
47
48 public:
// NOTE(review): listing lines 49-50 and 57-60 are missing from this extraction;
// the tooltip text at the end of the page indicates they declare the Problem,
// Policy, ADataType, BDataType and CDataType aliases.
// Mirrors the TransposeC_ template argument.
51 static constexpr bool TransposeC = TransposeC_;
52
53 using Traits = GemmTraits_<Problem, Policy>;
54
55 using WarpGemm = typename Traits::WarpGemm;
56 using BlockGemmShape = typename Traits::BlockGemmShape;
57
61
// Per-warp iteration counts and warp counts re-exported from Traits.
62 static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
63 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
64 static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
65
66 static constexpr index_t MWarp = Traits::MWarp;
67 static constexpr index_t NWarp = Traits::NWarp;
// True when the problem declares a wave-group count other than 1; selects which
// `if constexpr` branch the Make* functions below take.
68 static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
69
71 {
// Body of MakeABlockDistributionEncode(). Both branches build a block-level
// ("outer") encoding, embed WarpGemm::AWarpDstrEncoding into it with
// detail::make_embed_tile_distribution_encoding, and return the combined
// encoding. The branch is chosen at compile time by UseDefaultScheduler.
// NOTE(review): the signature (listing line 70) and most template arguments of
// the outer encodings (lines 75-76, 79-80, 90-95) are missing from this
// extraction, so how the two branches differ cannot be seen here.
72 if constexpr(UseDefaultScheduler)
73 {
74 constexpr auto a_block_outer_dstr_encoding =
77 tuple<>,
78 tuple<>,
81
82 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
83 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
84
85 return a_block_dstr_encode;
86 }
87 else
88 {
89 constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
96 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
97 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
98
99 return a_block_dstr_encode;
100 }
101 }
102
104 {
// Body of MakeBBlockDistributionEncode(). Both branches build a block-level
// ("outer") encoding, embed WarpGemm::BWarpDstrEncoding into it with
// detail::make_embed_tile_distribution_encoding, and return the combined
// encoding. The branch is chosen at compile time by UseDefaultScheduler.
// NOTE(review): the signature (listing line 103) and most template arguments of
// the outer encodings (lines 108-109, 112-113, 122-127) are missing from this
// extraction, so how the two branches differ cannot be seen here.
105 if constexpr(UseDefaultScheduler)
106 {
107 constexpr auto b_block_outer_dstr_encoding =
110 tuple<>,
111 tuple<>,
114 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
115 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
116
117 return b_block_dstr_encode;
118 }
119 else
120 {
121 constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
128 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
129 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
130
131 return b_block_dstr_encode;
132 }
133 }
134
136 {
// Body of MakeCBlockDistributionEncode(). Both branches build a block-level
// ("outer") encoding, embed WarpGemm::CWarpDstrEncoding into it with
// detail::make_embed_tile_distribution_encoding, and return the combined
// encoding. The branch is chosen at compile time by UseDefaultScheduler.
// NOTE(review): the signature (listing line 135) and several template arguments
// of the outer encodings (lines 141-142, 146, 155-158, 160) are missing from
// this extraction.
// Y-dimension major order handed to the outer encoding: <2, 1> when TransposeC
// is set, <1, 2> otherwise.
137 using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
138 if constexpr(UseDefaultScheduler)
139 {
140 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
143 tuple<>,
144 tuple<>,
145 c_distr_ys_major,
147 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
148 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
149
150 return c_block_dstr_encode;
151 }
152 else
153 {
154 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
159 c_distr_ys_major,
161 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
162 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
163
164 return c_block_dstr_encode;
165 }
166 }
167
168 // C += A * B
// Updates c_block_tensor in place: for every (k, m, n) warp iteration the
// matching C slice is read, passed through WarpGemm together with the A and B
// slices, and written back.
//   c_block_tensor - in/out C tensor; element type must be CDataType.
//   a_block_tensor - input A tensor; element type must be ADataType.
//   b_block_tensor - input B tensor; element type must be BDataType.
// All three tensors must carry exactly the distribution encodings produced by
// Make{A,B,C}BlockDistributionEncode(); this is enforced at compile time.
169 template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
170 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
171 const ABlockTensor& a_block_tensor,
172 const BBlockTensor& b_block_tensor) const
173 {
// Element types of the arguments (ignoring cv-qualifiers) must match the
// class's A/B/C data types.
174 static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
175 std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
176 std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
177 "wrong!");
178
179 // check ABC-block-distribution
180 static_assert(
181 std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
182 remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
183 .get_static_tile_distribution_encoding())>>,
184 "A distribution is wrong!");
185 static_assert(
186 std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
187 remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
188 .get_static_tile_distribution_encoding())>>,
189 "B distribution is wrong!");
190 static_assert(
191 std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
192 remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
193 .get_static_tile_distribution_encoding())>>,
194 "C distribution is wrong!");
195
196 using AWarpDstr = typename WarpGemm::AWarpDstr;
197 using BWarpDstr = typename WarpGemm::BWarpDstr;
198 using CWarpDstr = typename WarpGemm::CWarpDstr;
199
200 using AWarpTensor = typename WarpGemm::AWarpTensor;
201 using BWarpTensor = typename WarpGemm::BWarpTensor;
202 using CWarpTensor = typename WarpGemm::CWarpTensor;
203
// Y-dimension lengths of each warp-level distribution, as sequences; they form
// the tail of every slice-length argument below.
204 constexpr auto a_warp_y_lengths =
205 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
206 constexpr auto b_warp_y_lengths =
207 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
208 constexpr auto c_warp_y_lengths =
209 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
210
// All-zero sequences with NDimY entries; they form the tail of every slice
// origin below.
211 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
212 constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
213 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
214
215 // hot loop: K outermost, then M, then N. The loop indices are compile-time
// constants (they are used as template arguments of sequence<> below).
216 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
217 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
218 // read A warp tensor from A block tensor; done once per (kIter, mIter) and
// reused for every nIter. Slice origin is (mIter, kIter, 0...), slice lengths
// are (1, 1, warp Y lengths...).
219 AWarpTensor a_warp_tensor;
220 a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
221 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
222 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
223
224 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
225 // read B warp tensor from B block tensor; re-read for every
// (kIter, mIter, nIter). Slice origin is (nIter, kIter, 0...).
226 BWarpTensor b_warp_tensor;
227 b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
228 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
229 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
230
231 // read C warp tensor from C block tensor; the leading slice index is
// (nIter, mIter) when TransposeC is set, (mIter, nIter) otherwise.
232 using c_iter_idx = std::
233 conditional_t<TransposeC, sequence<nIter, mIter>, sequence<mIter, nIter>>;
234 CWarpTensor c_warp_tensor;
235 c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
236 merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
237 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
238
239 // warp GEMM on the three slices; c_warp_tensor is the in/out operand
240 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
241
242 // write C warp tensor into C block tensor, at the same slice it was read from
243 c_block_tensor.set_y_sliced_thread_data(
244 merge_sequences(c_iter_idx{}, c_warp_y_index_zeros),
245 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
246 c_warp_tensor.get_thread_buffer());
247 });
248 });
249 });
250 }
251
// Creates and returns a C block tensor of element type CDataType. Each branch
// builds an outer encoding, embeds WarpGemm::CWarpDstrEncoding into it, turns
// the result into a static tile distribution and makes a static distributed
// tensor from that distribution.
// NOTE(review): the encoding construction appears to repeat
// MakeCBlockDistributionEncode(); parts of both (here listing lines 258-259,
// 263, 274-277, 279) are missing from this extraction, so confirm they are
// identical before deduplicating.
// NOTE(review): no explicit initialization of the tensor contents is visible
// here - confirm what make_static_distributed_tensor guarantees.
252 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
253 {
// Y-dimension major order: <2, 1> when TransposeC is set, <1, 2> otherwise.
254 using c_distr_ys_major = std::conditional_t<TransposeC, sequence<2, 1>, sequence<1, 2>>;
255 if constexpr(UseDefaultScheduler)
256 {
257 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
260 tuple<>,
261 tuple<>,
262 c_distr_ys_major,
264
265 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
266 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
267 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
268 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
269 return c_block_tensor;
270 }
271 else
272 {
273 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
278 c_distr_ys_major,
280
281 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
282 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
283 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
284 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
285 return c_block_tensor;
286 }
287 }
288
289 // C = A * B
// Convenience overload: creates a C tensor with MakeCBlockTile(), runs the
// three-argument operator() on it and returns the tensor by value.
//   a_block_tensor - input A tensor, forwarded unchanged.
//   b_block_tensor - input B tensor, forwarded unchanged.
// NOTE(review): the three-argument overload reads the existing C values before
// writing them back, and no explicit clear of the fresh tensor is visible
// here - confirm the tile starts zeroed if a pure product is intended.
290 template <typename ABlockTensor, typename BBlockTensor>
291 CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
292 const BBlockTensor& b_block_tensor) const
293 {
294 auto c_block_tensor = MakeCBlockTile();
295 operator()(c_block_tensor, a_block_tensor, b_block_tensor);
296 return c_block_tensor;
297 }
298};
299
300} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
Definition block_gemm_areg_breg_creg_v1_default_policy.hpp:14
Definition block_gemm_areg_breg_creg_v1.hpp:18
static constexpr index_t KIterPerWarp
Definition block_gemm_areg_breg_creg_v1.hpp:62
CK_TILE_DEVICE auto operator()(const ABlockTensor &a_block_tensor, const BBlockTensor &b_block_tensor) const
Definition block_gemm_areg_breg_creg_v1.hpp:291
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_gemm_areg_breg_creg_v1.hpp:252
GemmTraits_< Problem, Policy > Traits
Definition block_gemm_areg_breg_creg_v1.hpp:53
remove_cvref_t< Policy_ > Policy
Definition block_gemm_areg_breg_creg_v1.hpp:50
typename Traits::WarpGemm WarpGemm
Definition block_gemm_areg_breg_creg_v1.hpp:55
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_gemm_areg_breg_creg_v1.hpp:60
static constexpr bool UseDefaultScheduler
Definition block_gemm_areg_breg_creg_v1.hpp:68
static CK_TILE_DEVICE constexpr auto MakeCBlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v1.hpp:135
static constexpr index_t MIterPerWarp
Definition block_gemm_areg_breg_creg_v1.hpp:63
static constexpr index_t NWarp
Definition block_gemm_areg_breg_creg_v1.hpp:67
static constexpr bool TransposeC
Definition block_gemm_areg_breg_creg_v1.hpp:51
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v1.hpp:70
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_gemm_areg_breg_creg_v1.hpp:59
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_gemm_areg_breg_creg_v1.hpp:58
typename Traits::BlockGemmShape BlockGemmShape
Definition block_gemm_areg_breg_creg_v1.hpp:56
static constexpr index_t NIterPerWarp
Definition block_gemm_areg_breg_creg_v1.hpp:64
remove_cvref_t< Problem_ > Problem
Definition block_gemm_areg_breg_creg_v1.hpp:49
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, const ABlockTensor &a_block_tensor, const BBlockTensor &b_block_tensor) const
Definition block_gemm_areg_breg_creg_v1.hpp:170
static constexpr index_t MWarp
Definition block_gemm_areg_breg_creg_v1.hpp:66
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_gemm_areg_breg_creg_v1.hpp:103
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192