moe_smoothquant_kernel.hpp Source File

moe_smoothquant_kernel.hpp Source File#

Composable Kernel: moe_smoothquant_kernel.hpp Source File
moe_smoothquant_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// Host-side arguments for the MoE smooth-quant kernel (aliased as Hargs inside the kernel
// struct). MakeKargs copies these field-for-field into the device-side Kargs.
13{
14 const void* p_x; // [tokens, hidden_size], input, fp16/bf16
15 const void* p_smscale; // [experts, hidden_size], input, column-wise smooth scale, fp32
16 const void* p_topk_ids; // [tokens, topk], input; the kernel reads entries as index_t expert ids
17
18 void* p_yscale; // [topk, tokens, 1], output, row-wise quant scale (kernel offsets by i_topk * tokens)
19 void* p_qy; // [topk, tokens, hidden_size], output (kernel offsets by i_topk * tokens * y_stride)
20
25 index_t x_stride; // row stride of input x, in elements
26 index_t y_stride; // row stride of output qy, in elements; one topk slice spans tokens * y_stride
27};
28
29// TODO: Extract some type to wrapper class
30template <typename Pipeline_>
32{
// Problem descriptor supplied by the pipeline; all tile-shape constants below are read from it.
34 using Problem = typename Pipeline::Problem;
35
41
// Tile handled by one workgroup: the x / qy windows in operator() are Block_M x Block_N.
42 static constexpr index_t Block_M = Problem::BlockShape::Block_M;
43 static constexpr index_t Block_N = Problem::BlockShape::Block_N;
44 static constexpr bool kPadM = false; // padding along M is never needed
// kPadN and kTwoPass also select the "_pn" / "_2p" suffixes produced by GetName().
45 static constexpr bool kPadN = Problem::kPadN;
46 static constexpr bool kTwoPass = Problem::kTwoPass;
47
48 static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
49 static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
50 static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
51 static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
52
// Compile-time index constants 0 and 1.
53 static constexpr auto I0 = number<0>{};
54 static constexpr auto I1 = number<1>{};
55
// Block shapes with Repeat_M != 1 are rejected at compile time.
56 static_assert(Problem::BlockShape::Repeat_M == 1);
57
// Device-side kernel arguments. MakeKargs brace-initializes this struct from Hargs, so the
// member order here must stay in sync with the initializer list in MakeKargs.
58 struct Kargs
59 {
60 const void* p_x; // [tokens, hidden_size], input, fp16/bf16
61 const void* p_smscale; // [experts, hidden_size], input, column-wise smooth scale, fp32
62 const void* p_topk_ids; // [tokens, topk], input; read in operator() as index_t expert ids
63
64 void* p_yscale; // [topk, tokens, 1], output, row-wise quant scale
65 void* p_qy; // [topk, tokens, hidden_size], output
66
71 index_t x_stride; // row stride of input x, in elements
72 index_t y_stride; // row stride of output qy, in elements; one topk slice spans tokens * y_stride
73 };
75
// Builds the device-side Kargs from the host-side Hargs.
// Every field is copied unchanged. The initializer is positional, so the order below
// (pointers, then tokens / hidden_size / experts / topk, then the two strides) has to
// match the declaration order of Kargs' members.
76 CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
77 {
78 return Kargs{hargs.p_x,
79 hargs.p_smscale,
80 hargs.p_topk_ids,
81 hargs.p_yscale,
82 hargs.p_qy,
83 hargs.tokens,
84 hargs.hidden_size,
85 hargs.experts,
86 hargs.topk,
87 hargs.x_stride,
88 hargs.y_stride};
89 }
90
// Launch grid for the kernel: x = topk, y = ceil(tokens / Block_M), z = 1.
// operator() reads blockIdx.x as the top-k slot and blockIdx.y * Block_M as the first
// token row of the workgroup's tile.
91 CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
92 {
93 return dim3(hargs.topk, integer_divide_ceil(hargs.tokens, Block_M), 1);
94 }
95
// Workgroup size for the launch: BlockShape::GetBlockSize<true>() when is_wave32() reports
// true, otherwise GetBlockSize<false>().
// NOTE(review): is_wave32() looks like a host-side runtime query, so this function is
// presumably not usable in constant expressions despite `constexpr` - confirm.
96 CK_TILE_HOST static constexpr auto BlockSize()
97 {
98 return is_wave32() ? Problem::BlockShape::template GetBlockSize<true>()
99 : Problem::BlockShape::template GetBlockSize<false>();
100 }
101
// t2s<T>::name maps a data type to the short tag used in kernel names (see GetName()).
// Only the types specialized below are supported; any other T leaves t2s<T> incomplete
// and fails to compile where ::name is used.
102 // clang-format off
103 template <typename T> struct t2s;
104 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
105 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
106 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
107 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
108 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
109 template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "i8"; };
110 // clang-format on
111
112 // Shared-memory requirement in bytes, forwarded from Pipeline::GetSmemSize().
// operator() uses this value as the length of its __shared__ char buffer.
113 CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
114
// Returns a name string for this kernel instantiation, laid out as
//   moe_smoothquant_<x type>_<qy type>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_
//   <Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<Pipeline::name>[_pn][_2p]
// where the type tags come from t2s<> and the numbers from Problem::BlockShape.
115 CK_TILE_HOST static std::string GetName()
116 {
117 // clang-format off
118 using S_ = typename Problem::BlockShape;
// Optional suffix ("surfix" [sic]): "_pn" when kPadN is set, then "_2p" when kTwoPass is set.
119 auto surfix = [&] () {
120 std::string n;
121 if (kPadN) n += "_pn";
122 if (kTwoPass) n += "_2p";
123 return n; }();
124
// _SS_ / _TS_ are function-local shorthands for std::string / std::to_string; both are
// #undef'd again right after the return expression.
125 #define _SS_ std::string
126 #define _TS_ std::to_string
127 return _SS_("moe_smoothquant_") + _SS_(t2s<XDataType>::name) + "_" + _SS_(t2s<QYDataType>::name) + "_" +
128 _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
129 _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
130 _SS_(Pipeline::name) + surfix;
131 #undef _SS_
132 #undef _TS_
133 // clang-format on
134 }
135
137 {
// Kernel body. The grid comes from GridSize(): blockIdx.x selects the top-k slot and
// blockIdx.y selects a tile of Block_M consecutive token rows.
138 const index_t i_topk = blockIdx.x;
139 const index_t i_token = blockIdx.y * Block_M;
// Row offset inside the tile, derived from the thread id; presumably made wave-uniform by
// amd_wave_read_first_lane - confirm against its definition.
140 const index_t i_token_in_thrd =
141 amd_wave_read_first_lane(threadIdx.x / Problem::BlockShape::ThreadPerBlock_N);
142
// Expert id for (token row, top-k slot), read from p_topk_ids viewed as index_t[tokens][topk].
// NOTE(review): no check against kargs.tokens is visible here. GridSize() rounds the token
// count up to a multiple of Block_M, so the last block may read past the end of the table
// when tokens % Block_M != 0 - confirm the caller guarantees this is safe.
143 const index_t i_expert = reinterpret_cast<const index_t*>(
144 kargs.p_topk_ids)[(i_token + i_token_in_thrd) * kargs.topk + i_topk];
145
146 // Input x viewed as [tokens, hidden_size] with row stride x_stride; the window is the
// Block_M x Block_N tile starting at row i_token, column 0.
147 const auto x_window = [&]() {
149 static_cast<const XDataType*>(kargs.p_x),
150 make_tuple(kargs.tokens, kargs.hidden_size),
151 make_tuple(kargs.x_stride, 1),
153 number<1>{});
154
155 const auto tmp2_ = pad_tensor_view(
157 return make_tile_window(
158 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
159 }();
160
161 // Smooth scale is stored as [experts, hidden_size]; only the row of the selected expert
// (base + i_expert * hidden_size) is viewed, as a 1-D window of Block_N elements.
162 const auto smscale_window = [&]() {
164 static_cast<const SmoothScaleDataType*>(kargs.p_smscale) +
165 i_expert * kargs.hidden_size,
166 make_tuple(kargs.hidden_size),
167 make_tuple(1),
169 number<1>{});
170
171 const auto tmp2_ =
173
174 return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
175 }();
176
177 // Output scale is stored as [topk, tokens]; the view starts at this top-k slot's slice
// (base + i_topk * tokens) and the window covers Block_M rows from i_token.
178 auto yscale_window = [&]() {
180 static_cast<YScaleDataType*>(kargs.p_yscale) + i_topk * kargs.tokens,
181 make_tuple(kargs.tokens),
182 make_tuple(1),
183 number<1>{});
184
185 const auto tmp2_ =
187
188 return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {i_token});
189 }();
190
191 // Quantized output is stored as [topk, tokens, hidden_size] with row stride y_stride; the
// view starts at this top-k slot's slice (base + i_topk * tokens * y_stride).
// NOTE(review): the slice offset is computed in index_t (32-bit), so it could overflow
// for large topk * tokens * y_stride - confirm the supported size range.
192 auto qy_window = [&]() {
194 static_cast<QYDataType*>(kargs.p_qy) + i_topk * kargs.tokens * kargs.y_stride,
195 make_tuple(kargs.tokens, kargs.hidden_size),
196 make_tuple(kargs.y_stride, 1),
198 number<1>{});
199
200 auto tmp2_ = pad_tensor_view(
202 return make_tile_window(
203 tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {i_token, 0});
204 }();
205
// Shared-memory scratch buffer for the pipeline, sized in bytes by GetSmemSize().
206 __shared__ char smem[GetSmemSize()];
207
// Hand the four windows, the row length and the scratch buffer to the pipeline.
208 Pipeline{}(x_window, smscale_window, yscale_window, qy_window, kargs.hidden_size, smem);
209 }
210};
211
212} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
int8_t int8_t
Definition int8.hpp:20
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
Definition moe_smoothquant_kernel.hpp:59
index_t y_stride
Definition moe_smoothquant_kernel.hpp:72
const void * p_x
Definition moe_smoothquant_kernel.hpp:60
index_t tokens
Definition moe_smoothquant_kernel.hpp:67
index_t topk
Definition moe_smoothquant_kernel.hpp:70
index_t x_stride
Definition moe_smoothquant_kernel.hpp:71
void * p_yscale
Definition moe_smoothquant_kernel.hpp:64
const void * p_smscale
Definition moe_smoothquant_kernel.hpp:61
void * p_qy
Definition moe_smoothquant_kernel.hpp:65
index_t experts
Definition moe_smoothquant_kernel.hpp:69
const void * p_topk_ids
Definition moe_smoothquant_kernel.hpp:62
index_t hidden_size
Definition moe_smoothquant_kernel.hpp:68
static constexpr const char * name
Definition moe_smoothquant_kernel.hpp:106
static constexpr const char * name
Definition moe_smoothquant_kernel.hpp:108
static constexpr const char * name
Definition moe_smoothquant_kernel.hpp:105
static constexpr const char * name
Definition moe_smoothquant_kernel.hpp:107
static constexpr const char * name
Definition moe_smoothquant_kernel.hpp:109
static constexpr const char * name
Definition moe_smoothquant_kernel.hpp:104
Definition moe_smoothquant_kernel.hpp:103
Definition moe_smoothquant_kernel.hpp:13
index_t x_stride
Definition moe_smoothquant_kernel.hpp:25
index_t topk
Definition moe_smoothquant_kernel.hpp:24
index_t hidden_size
Definition moe_smoothquant_kernel.hpp:22
void * p_yscale
Definition moe_smoothquant_kernel.hpp:18
index_t experts
Definition moe_smoothquant_kernel.hpp:23
index_t y_stride
Definition moe_smoothquant_kernel.hpp:26
index_t tokens
Definition moe_smoothquant_kernel.hpp:21
const void * p_topk_ids
Definition moe_smoothquant_kernel.hpp:16
const void * p_smscale
Definition moe_smoothquant_kernel.hpp:15
void * p_qy
Definition moe_smoothquant_kernel.hpp:19
const void * p_x
Definition moe_smoothquant_kernel.hpp:14
Definition moe_smoothquant_kernel.hpp:32
MoeSmoothquantHostArgs Hargs
Definition moe_smoothquant_kernel.hpp:74
static constexpr bool kTwoPass
Definition moe_smoothquant_kernel.hpp:46
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition moe_smoothquant_kernel.hpp:37
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_smoothquant_kernel.hpp:136
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition moe_smoothquant_kernel.hpp:39
static constexpr bool kPadM
Definition moe_smoothquant_kernel.hpp:44
static constexpr auto I0
Definition moe_smoothquant_kernel.hpp:53
static constexpr bool kPadN
Definition moe_smoothquant_kernel.hpp:45
static CK_TILE_HOST constexpr auto GridSize(const Hargs &hargs)
Definition moe_smoothquant_kernel.hpp:91
remove_cvref_t< typename Problem::QYDataType > QYDataType
Definition moe_smoothquant_kernel.hpp:40
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition moe_smoothquant_kernel.hpp:38
remove_cvref_t< Pipeline_ > Pipeline
Definition moe_smoothquant_kernel.hpp:33
remove_cvref_t< typename Problem::XDataType > XDataType
Definition moe_smoothquant_kernel.hpp:36
static constexpr index_t Vector_N
Definition moe_smoothquant_kernel.hpp:49
static CK_TILE_HOST constexpr auto BlockSize()
Definition moe_smoothquant_kernel.hpp:96
static CK_TILE_HOST constexpr Kargs MakeKargs(const Hargs &hargs)
Definition moe_smoothquant_kernel.hpp:76
static constexpr index_t Block_N
Definition moe_smoothquant_kernel.hpp:43
static CK_TILE_HOST std::string GetName()
Definition moe_smoothquant_kernel.hpp:115
typename Pipeline::Problem Problem
Definition moe_smoothquant_kernel.hpp:34
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition moe_smoothquant_kernel.hpp:113
static constexpr index_t Repeat_N
Definition moe_smoothquant_kernel.hpp:50
static constexpr index_t ThreadPerWarp_N
Definition moe_smoothquant_kernel.hpp:48
static constexpr index_t kBlockSize
Definition moe_smoothquant_kernel.hpp:51
static constexpr auto I1
Definition moe_smoothquant_kernel.hpp:54
static constexpr index_t Block_M
Definition moe_smoothquant_kernel.hpp:42
Definition tile/core/container/sequence.hpp:49