Composable Kernel: rmsnorm2d_fwd_kernel.hpp Source File

Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// host side args: type-erased buffer pointers, epsilon, problem sizes (m, n) and row strides
14{
15 const void* p_x; // [m, n], input, fp16/bf16
16 const void* p_x_residual; // [m, n], shortcut (residual) input, prec same as input, nullptr if not used
17 const void* p_sm_scale; // [1, n], smooth scale input, fp32, nullptr if not used
18 const void* p_gamma; // [1, n], gamma, prec same as input
19
20 void* p_y; // [m, n], output, fp16/bf16
21 void* p_y_residual; // [m, n], shortcut (residual) output, prec same as input, nullptr if not used
22 void* p_y_scale; // [m, 1], output per-row dynamic-quant scale, nullptr if not used
23 void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
24 void* p_y_unquant; // [m, n], output result before quant (addressed with y_stride), nullptr if not used
25
26 float epsilon; // passed to the pipeline after a cast to ComputeDataType
27
30 index_t x_stride; // x row stride
31 index_t xr_stride; // x residual row stride
32 index_t y_stride; // y row stride (also used for p_y_unquant)
33 index_t yr_stride; // y residual row stride
34};
35
36// TODO: Extract some type to wrapper class
37template <typename Pipeline_, typename Epilogue_>
39{
42 using Problem = typename Pipeline::Problem; // data types, block shape and traits all come from the pipeline's Problem
43
52
53 // for simplicity, shortcut (residual) input/output type is same as X
56
57 static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>; // gamma is optional: a null_type GammaDataType disables it
58 static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
59 static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
60
61 static constexpr index_t Block_M = Problem::BlockShape::Block_M;
62 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
63 static constexpr bool kPadM = false; // M is never padded; only N padding is configurable (kPadN)
64 static constexpr bool kPadN = Problem::Traits::kPadN;
65 static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
66 static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; // presumably NO_ADD / PRE_ADD / PRE_ADD_STORE - confirm in rmsnorm2d_fwd_traits.hpp
67 static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; // presumably NO_SWEEP / SMOOTH_DYNAMIC_QUANT / DYNAMIC_QUANT - confirm
68 static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm; // presumably NO_SPECIFIC_MODEL / T5_MODEL_LIKE - confirm
69
70 static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
71 static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
72 static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
73 static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize; // NOTE(review): BlockSize() launches with GetBlockSize<is_wave32()>() instead - confirm both agree
74
75 static constexpr auto I0 = number<0>{}; // compile-time index constant 0
76 static constexpr auto I1 = number<1>{}; // compile-time index constant 1
77
78 struct Kargs // kernel-side arguments, filled from the host args by MakeKargs()
79 {
80 const void* p_x;
81 const void* p_x_residual;
82 const void* p_sm_scale;
83 const void* p_gamma;
84
85 void* p_y;
87 void* p_y_scale;
88 void* p_invRms;
90
91 float epsilon;
92
95 index_t x_stride; // x row stride
96 index_t xr_stride; // x residual row stride
97 index_t y_stride; // y row stride
98 index_t yr_stride; // y residual row stride
99 };
101
102 CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) // field-by-field copy of the host args
103 {
104 return Kargs{hargs.p_x, // aggregate init: this order must match Kargs' member order
105 hargs.p_x_residual,
106 hargs.p_sm_scale,
107 hargs.p_gamma,
108 hargs.p_y,
109 hargs.p_y_residual,
110 hargs.p_y_scale,
111 hargs.p_invRms,
112 hargs.p_y_unquant,
113 hargs.epsilon,
114 hargs.m,
115 hargs.n,
116 hargs.x_stride,
117 hargs.xr_stride,
118 hargs.y_stride,
119 hargs.yr_stride};
120 }
121
122 CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) // launch grid for the given problem
123 {
124 return dim3(integer_divide_ceil(hargs.m, Block_M)); // one workgroup per Block_M rows: ceil(m / Block_M) along x
125 }
126
127 CK_TILE_HOST static constexpr auto BlockSize() // threads per workgroup, chosen at runtime for wave32 vs wave64 devices
128 {
129 return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>() // NOTE(review): is_wave32() is host-only and not constexpr, so this constexpr function can never be constant-evaluated; consider dropping constexpr
130 : Problem::BlockShape::template GetBlockSize<false>();
131 }
132
133 // clang-format off
134 template <typename T> struct t2s; // type -> short name used by GetName(); only the specializations below are defined, so other types fail to compile
135 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
136 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
137 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
138 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
139 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
140 template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
141 // clang-format on
142
143 // shared-memory size in bytes, forwarded from the pipeline; also sizes the __shared__ buffer in operator()
144 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
145
146 CK_TILE_HOST static std::string GetName() // kernel identifier: precision, block/warp/vector shape, pipeline name, trait suffix
147 {
148#define _SS_ std::string
149#define _TS_ std::to_string
150 // clang-format off
151 using S_ = typename Problem::BlockShape;
152 auto surfix = [&] () { // suffix listing the enabled traits (e.g. _pn, _rms, _2p)
153 std::string n;
156 if (kPadN) n += "_pn";
157 if (kSaveInvRms) n += "_rms";
158 if (kTwoPass) n += "_2p";
161 return n; }();
162
163 auto prec_str = [&] () { // precision tag: x type, y type when it differs, then any quant-scale types
164 std::string base_str = _SS_(t2s<XDataType>::name);
165 if (!std::is_same_v<XDataType, YDataType>) {
166 base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
167 }
169 base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
170 base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
171 }
173 base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
174 }
175 return base_str;
176 }();
177
178 return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
179 _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
180 _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
181 _SS_(Pipeline::name) + surfix;
182 // clang-format on
183#undef _SS_
184#undef _TS_
185 }
186
188 { // operator()(Kargs kargs) const: build this workgroup's tile windows, then run the pipeline
189 const auto iM = get_block_id() * Block_M; // first row of this workgroup's Block_M-row tile
190
191 const auto x_window = [&]() { // [Block_M, Block_N] input window starting at row iM
193 static_cast<const XDataType*>(kargs.p_x),
194 make_tuple(kargs.m, kargs.n),
195 make_tuple(kargs.x_stride, 1),
197 number<1>{});
198
199 const auto tmp2_ = pad_tensor_view(
201 return make_tile_window(
202 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
203 }();
204
205 const auto x_residual_window = [&]() { // residual-input window; the else branch presumably yields a null window when unused
208 {
210 static_cast<const XResidualDataType*>(kargs.p_x_residual),
211 make_tuple(kargs.m, kargs.n),
212 make_tuple(kargs.xr_stride, 1),
214 number<1>{});
215
216 const auto tmp2_ = pad_tensor_view(tmp_,
219 return make_tile_window(
220 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
221 }
222 else
223 {
225 }
226 }();
227
228 const auto gamma_window = [&]() { // 1-D [Block_N] gamma window at column 0, identical for every row block
230 static_cast<const GammaDataType*>(kargs.p_gamma),
231 make_tuple(kargs.n),
232 make_tuple(1),
234 number<1>{});
235
236 const auto tmp2_ =
238
239 return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
240 }();
241
242 auto y_window = [&]() { // [Block_M, Block_N] output window starting at row iM
244 static_cast<YDataType*>(kargs.p_y),
245 make_tuple(kargs.m, kargs.n),
246 make_tuple(kargs.y_stride, 1),
248 number<1>{});
249
250 auto tmp2_ = pad_tensor_view(
252 return make_tile_window(
253 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
254 }();
255
256 auto y_residual_window = [&]() { // residual-output window; the else branch presumably yields a null window when unused
258 {
260 static_cast<YResidualDataType*>(kargs.p_y_residual),
261 make_tuple(kargs.m, kargs.n),
262 make_tuple(kargs.yr_stride, 1),
264 number<1>{});
265
266 auto tmp2_ = pad_tensor_view(tmp_,
269 return make_tile_window(
270 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
271 }
272 else
273 {
275 }
276 }();
277
278 auto inv_rms_window = [&]() { // [Block_M] per-row inv-rms output window
279 if constexpr(kSaveInvRms) // only written when the traits request it
280 {
281 const auto inv_rms_m = [&]() {
282 const auto inv_rms_dram_naive =
284 static_cast<InvRmsDataType*>(kargs.p_invRms),
285 make_tuple(kargs.m),
286 number<1>{});
287
288 return pad_tensor_view(
289 inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{}); // kPadM is always false
290 }();
291 return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
292 }
293 else
295 }();
296
297 auto sm_scale_window = [&]() { // 1-D [Block_N] smooth-scale window at column 0
299 {
300 const auto win_ = [&]() {
302 static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
303 make_tuple(kargs.n),
305
306 return pad_tensor_view(tmp_0_,
308 sequence<false>{}); // sm_scale is never padded
309 }();
310 return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
311 }
312 else
313 {
315 }
316 }();
317
318 auto y_scale_window = [&]() { // [Block_M] per-row quant-scale output window starting at row iM
321 {
322 const auto win_ = [&]() {
324 static_cast<YScaleDataType*>(kargs.p_y_scale),
325 make_tuple(kargs.m),
326 number<1>{});
327
328 return pad_tensor_view(
330 }();
331 return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
332 }
333 else
334 {
336 }
337 }();
338
339 auto unquant_y_window = [&]() { // [Block_M, Block_N] pre-quant output window; NOTE: addressed with kargs.y_stride (no separate unquant stride exists)
343 {
345 static_cast<UnquantYDataType*>(kargs.p_y_unquant),
346 make_tuple(kargs.m, kargs.n),
347 make_tuple(kargs.y_stride, 1),
349 number<1>{});
350
351 auto tmp2_ = pad_tensor_view(tmp_,
354 return make_tile_window(
355 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
356 }
357 else
358 {
360 }
361 }();
362
363 __shared__ char smem[GetSmemSize()]; // workgroup shared memory sized by the pipeline
364
365 Pipeline{}(x_window, // the kernel only wires windows; the math lives in the pipeline and epilogue
366 x_residual_window,
367 gamma_window,
368 y_window,
369 y_residual_window,
370 inv_rms_window,
371 sm_scale_window,
372 y_scale_window,
373 unquant_y_window,
374 static_cast<const ComputeDataType>(kargs.epsilon), // epsilon in the compute precision
375 kargs.n, // runtime row length
376 smem,
377 Epilogue{});
378 }
379};
380
381} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
@ NO_SPECIFIC_MODEL
Definition rmsnorm2d_fwd_traits.hpp:42
@ T5_MODEL_LIKE
Definition rmsnorm2d_fwd_traits.hpp:46
@ NO_SWEEP
Definition rmsnorm2d_fwd_traits.hpp:28
@ SMOOTH_DYNAMIC_QUANT
Definition rmsnorm2d_fwd_traits.hpp:29
@ DYNAMIC_QUANT
Definition rmsnorm2d_fwd_traits.hpp:30
int8_t int8_t
Definition int8.hpp:20
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
@ PRE_ADD_STORE
Definition rmsnorm2d_fwd_traits.hpp:14
@ PRE_ADD
Definition rmsnorm2d_fwd_traits.hpp:16
@ NO_ADD
Definition rmsnorm2d_fwd_traits.hpp:12
_BitInt(8) fp8_t
Definition float8.hpp:204
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
CK_TILE_DEVICE index_t get_block_id()
Definition arch.hpp:119
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view_packed(DataType *__restrict__ p, const tuple< Lengths... > &lengths, number< GuaranteedLastDimensionVectorLength >=number<-1 >{})
Definition tensor_view.hpp:494
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition rmsnorm2d_fwd_traits.hpp:20
Definition rmsnorm2d_fwd_traits.hpp:34
Definition rmsnorm2d_fwd_kernel.hpp:79
void * p_invRms
Definition rmsnorm2d_fwd_kernel.hpp:88
void * p_y_scale
Definition rmsnorm2d_fwd_kernel.hpp:87
index_t n
Definition rmsnorm2d_fwd_kernel.hpp:94
const void * p_x
Definition rmsnorm2d_fwd_kernel.hpp:80
index_t yr_stride
Definition rmsnorm2d_fwd_kernel.hpp:98
index_t y_stride
Definition rmsnorm2d_fwd_kernel.hpp:97
void * p_y
Definition rmsnorm2d_fwd_kernel.hpp:85
index_t xr_stride
Definition rmsnorm2d_fwd_kernel.hpp:96
const void * p_sm_scale
Definition rmsnorm2d_fwd_kernel.hpp:82
void * p_y_residual
Definition rmsnorm2d_fwd_kernel.hpp:86
void * p_y_unquant
Definition rmsnorm2d_fwd_kernel.hpp:89
index_t m
Definition rmsnorm2d_fwd_kernel.hpp:93
const void * p_gamma
Definition rmsnorm2d_fwd_kernel.hpp:83
float epsilon
Definition rmsnorm2d_fwd_kernel.hpp:91
const void * p_x_residual
Definition rmsnorm2d_fwd_kernel.hpp:81
index_t x_stride
Definition rmsnorm2d_fwd_kernel.hpp:95
static constexpr const char * name
Definition rmsnorm2d_fwd_kernel.hpp:137
static constexpr const char * name
Definition rmsnorm2d_fwd_kernel.hpp:139
static constexpr const char * name
Definition rmsnorm2d_fwd_kernel.hpp:136
static constexpr const char * name
Definition rmsnorm2d_fwd_kernel.hpp:138
static constexpr const char * name
Definition rmsnorm2d_fwd_kernel.hpp:140
static constexpr const char * name
Definition rmsnorm2d_fwd_kernel.hpp:135
Definition rmsnorm2d_fwd_kernel.hpp:134
Definition rmsnorm2d_fwd_kernel.hpp:14
void * p_invRms
Definition rmsnorm2d_fwd_kernel.hpp:23
index_t xr_stride
Definition rmsnorm2d_fwd_kernel.hpp:31
void * p_y_residual
Definition rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition rmsnorm2d_fwd_kernel.hpp:16
void * p_y_scale
Definition rmsnorm2d_fwd_kernel.hpp:22
float epsilon
Definition rmsnorm2d_fwd_kernel.hpp:26
void * p_y
Definition rmsnorm2d_fwd_kernel.hpp:20
index_t yr_stride
Definition rmsnorm2d_fwd_kernel.hpp:33
void * p_y_unquant
Definition rmsnorm2d_fwd_kernel.hpp:24
index_t x_stride
Definition rmsnorm2d_fwd_kernel.hpp:30
index_t y_stride
Definition rmsnorm2d_fwd_kernel.hpp:32
index_t m
Definition rmsnorm2d_fwd_kernel.hpp:28
index_t n
Definition rmsnorm2d_fwd_kernel.hpp:29
const void * p_sm_scale
Definition rmsnorm2d_fwd_kernel.hpp:17
const void * p_x
Definition rmsnorm2d_fwd_kernel.hpp:15
const void * p_gamma
Definition rmsnorm2d_fwd_kernel.hpp:18
Definition rmsnorm2d_fwd_kernel.hpp:39
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition rmsnorm2d_fwd_kernel.hpp:187
XDataType XResidualDataType
Definition rmsnorm2d_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::UnquantYDataType > UnquantYDataType
Definition rmsnorm2d_fwd_kernel.hpp:51
remove_cvref_t< Epilogue_ > Epilogue
Definition rmsnorm2d_fwd_kernel.hpp:41
Rmsnorm2dFwdHostArgs Hargs
Definition rmsnorm2d_fwd_kernel.hpp:100
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition rmsnorm2d_fwd_kernel.hpp:50
static constexpr bool kTwoPass
Definition rmsnorm2d_fwd_kernel.hpp:65
static constexpr bool kSaveInvRms
Definition rmsnorm2d_fwd_kernel.hpp:58
static constexpr auto I0
Definition rmsnorm2d_fwd_kernel.hpp:75
typename Pipeline::Problem Problem
Definition rmsnorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition rmsnorm2d_fwd_kernel.hpp:48
static constexpr bool kPadN
Definition rmsnorm2d_fwd_kernel.hpp:64
static CK_TILE_HOST std::string GetName()
Definition rmsnorm2d_fwd_kernel.hpp:146
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition rmsnorm2d_fwd_kernel.hpp:102
remove_cvref_t< typename Problem::YDataType > YDataType
Definition rmsnorm2d_fwd_kernel.hpp:47
static constexpr auto kFusedQuant
Definition rmsnorm2d_fwd_kernel.hpp:67
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition rmsnorm2d_fwd_kernel.hpp:46
remove_cvref_t< Pipeline_ > Pipeline
Definition rmsnorm2d_fwd_kernel.hpp:40
static CK_TILE_HOST constexpr auto BlockSize()
Definition rmsnorm2d_fwd_kernel.hpp:127
static constexpr auto I1
Definition rmsnorm2d_fwd_kernel.hpp:76
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition rmsnorm2d_fwd_kernel.hpp:122
static constexpr bool kPadM
Definition rmsnorm2d_fwd_kernel.hpp:63
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition rmsnorm2d_fwd_kernel.hpp:49
static constexpr index_t Block_M
Definition rmsnorm2d_fwd_kernel.hpp:61
static constexpr auto kFusedAdd
Definition rmsnorm2d_fwd_kernel.hpp:66
XDataType YResidualDataType
Definition rmsnorm2d_fwd_kernel.hpp:55
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition rmsnorm2d_fwd_kernel.hpp:144
static constexpr bool kHasGamma
Definition rmsnorm2d_fwd_kernel.hpp:57
remove_cvref_t< typename Problem::XDataType > XDataType
Definition rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t kBlockSize
Definition rmsnorm2d_fwd_kernel.hpp:73
static constexpr index_t ThreadPerWarp_N
Definition rmsnorm2d_fwd_kernel.hpp:70
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition rmsnorm2d_fwd_kernel.hpp:45
static constexpr index_t Block_N
Definition rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveUnquant
Definition rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition rmsnorm2d_fwd_kernel.hpp:72
static constexpr auto kUseModelSensitiveRMSNorm
Definition rmsnorm2d_fwd_kernel.hpp:68
static constexpr index_t Vector_N
Definition rmsnorm2d_fwd_kernel.hpp:71
Definition tile/core/container/sequence.hpp:49