moe_sorting_kernel.hpp Source File

moe_sorting_kernel.hpp Source File#

Composable Kernel: moe_sorting_kernel.hpp Source File
moe_sorting_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10#include <string>
11#include <type_traits>
12
13#if !defined(CK_TILE_HAS_ROW_NEWBCAST)
14// row_newbcast (DPP modifier 0x157) support by architecture:
15// - Not supported: gfx908 (MI100) and older
16// - Supported: gfx90a (MI200), gfx94x (MI300), and all RDNA architectures
17
18#if defined(__HIP_DEVICE_COMPILE__) && defined(__HIP_PLATFORM_AMD__)
19#if defined(__gfx908__) || defined(__gfx906__) || defined(__gfx900__)
20// Explicitly disable for known unsupported architectures
21#define CK_TILE_HAS_ROW_NEWBCAST 0
22#else
23// Assume support for gfx90a and newer (including all gfx94x and RDNA)
24// This is safer as new architectures typically maintain backward compatibility
25#define CK_TILE_HAS_ROW_NEWBCAST 1
26#endif
27#else
28// Conservative default for non-AMD or host compilation
29#define CK_TILE_HAS_ROW_NEWBCAST 0
30#endif
31#endif
32
33namespace ck_tile {
34
35#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
36 static_cast<uint32_t>(((token_id_) & 0x00ffffff) | (((topk_id_) & 0xff) << 24))
37
38#ifndef MOE_SORTING_USE_EX_KERNEL
39#define MOE_SORTING_USE_EX_KERNEL 1
40#endif
41
42#ifndef MOE_SORTING_FUSE_MP_01
43#define MOE_SORTING_FUSE_MP_01 1
44#endif
45
46// whether to use 2d buffer indexing for the fmoe ws, or 1d
47#ifndef MOE_SORTING_FMOE_2D_BUF
48#define MOE_SORTING_FMOE_2D_BUF 1
49#endif
50
51// clang-format off
52// [indexing implementation-1]
53// using M_a as constexpr block_size to partition all tokens into different slices
54// each slice map to one expert, and one expert can have multiple slices
55// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
56// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
57// tok-0 tok-1 tok-2 tok-3 tok-4
58// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
59//
60// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
61// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
62// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
63//
64// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
65// * this could be larger than actual, since actual tokens are on GPU
66//
67// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
68// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
69// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
70//
71// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
72//
73// * Note on token_id_per_expert/sorted_token_ids_ptr data:
74// currently we do not have topk information from the data of token_id_per_expert/sorted_token_ids_ptr.
75// In some cases(like smooth-quant), we need topk information to indexing into tokens quant from
76// different expert smooth quant. So we modify the number stored inside token_id_per_expert/sorted_token_ids_ptr
77//
78// 32bit 0........23 24.....31 bit
79// (data) -> (token_id | topk_id)
80// low 24 bit is for token id, top 8 bit is for topk id
81//
82// the input after smooth-quant is [topk, token, hidden_dim], originally it is [token, hidden_dim]
83// the input scale for token is [topk, token, 1], the smooth-quant scale for first gemm is [expert, interm_dim]
84//
85// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
86// * length is (max_num_tokens_padded + block_size - 1) / block_size
87//
88// num_tokens_post_padded_ptr : [28]
89// num_sorted_tiles_ptr : [7]
90//
91// skip_experts_with_zero_tokens(SkipExpertsWithZeroTokens)
92// if enabled, an expert with no tokens will be skipped, instead of being padded to at least 1 unit_size(M_a)
93//
94// (pack below tensor, skip element marked with `-`)
95// Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
96// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
97// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
98// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
99//
100//
101// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 5]
102// num_tokens_post_padded_ptr : [24]
103//
104// * local_expert_mask : indicate local expert mask used on current GPU (used for EP case)
105// and modify the output expert-ID, because we will only have enabled experts on a specific GPU.
106// we call expert input to this kernel as "global expert id", output as "local expert id"
107//
108// * local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
109//
110// (pack below tensor, skip element marked with `-`)
111// Y Y Y Y - - - - Y Y Y Y Y Y Y Y Y Y Y Y - - - - Y Y Y Y
112// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 4]
113// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4 -|- exp-5 -|
114// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
115//
116// sorted_expert_ids_ptr : [0, 1, 2, 2, 3] (note originally it was expert-id = 0, 2, 3, 5, but we produce "local expert id")
117// num_tokens_post_padded_ptr : [20]
118//
119// * different from vLLM
120// 1) token_id stored in sorted_token_ids_ptr is actual token_id, not token_id*top_K expanded id
121// 2)need sorted_weight_ptr
122// 3) use num_sorted_tiles_ptr, already divided by M_a
123//
124// * below used for indexing
125// 1) sorted_token_ids_ptr [max_num_tokens_padded]
126// 2) sorted_weight_ptr
127// 3) sorted_expert_ids_ptr
128// 4)num_tokens_post_padded_ptr/num_sorted_tiles_ptr (select one)
129//
130// max_num_tokens_padded: topk_ids.numel() + num_experts * (block_size - 1)
131
132
133CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_experts_)
134{
135 /* num_experts + 1
136 * +--------------------------------------+
137 * | |
138 * | |
139 * | | * -> sub-tokens
140 * | |
141 * | |
142 * +--------------------------------------+
143 * | | 2 -> cumsum buffer
144 * +--------------------------------------+
145 *
146 */
147 int smem_cols = num_experts_ + 1; // usually experts is power of 2. padding here
148 int smem_rows = [&](){
149 index_t target_occupancy_ = 2;
150 constexpr index_t total_ = get_smem_capacity() / sizeof(index_t);
151 constexpr index_t sub_unroll = 8;
152 constexpr index_t cumsum_bufs = 2; // 1 for cumsum, 1 for cnt
153 // at lease 2 lines, one for sub_token unroll, one for cumsum
154 // should be enough
155
156 int r = total_ / target_occupancy_ / smem_cols;
157
158 // Note: at lease allocate cumsum_bufs + sub_unroll as num-row. Otherwise, fallback to mp kernel
159 if(r < (cumsum_bufs + sub_unroll))
160 return cumsum_bufs;
161
162 // round to sub_unroll multipl
163 int r_for_sub_token = r - cumsum_bufs;
164 r_for_sub_token = r_for_sub_token / sub_unroll * sub_unroll;
165 int r_token_min = (tokens_ + sub_unroll - 1) / sub_unroll * sub_unroll;
166 r_for_sub_token = min(r_for_sub_token, r_token_min);
167
168 // final check, but usually should not happen
169 if( ((r_for_sub_token + cumsum_bufs) * smem_cols * target_occupancy_ ) > total_ ) {
170 throw std::runtime_error("can't run this kernel, request LDS over size");
171 }
172
173 return r_for_sub_token + cumsum_bufs;
174 }();
175
176 return ck_tile::make_tuple(smem_rows, smem_cols);
177}
178
179// if return 0 or negative, means LDS is not enough
180CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
181{
182 auto [r_, c_] = moe_sorting_get_smem_row_col(tokens_, num_experts_);
183 auto sub_token_ = r_ - 2;
184 (void) c_;
185 return sub_token_;
186}
187
189{
190 const void* p_topk_ids; // [token, topk]
191 const void* p_weights; // [token, topk]
192
193 const void* p_local_expert_mask; // [experts]
194 const void* p_local_tokens; // [1] if not nullptr, tokens read from here
195
199 void* p_total_tokens_post_pad; // [2], [0]:outputed tokens_post_padded, [1]:actual tokens on current rank (local_tokens or tokens)
200 // we fused the setzero of output of fused-moe buffer
201 // set this pointer to nullptr will skip this operation
203 void* p_ws; // size is moe_sorting_get_workspace_size()
204 // if return zero, then could be nullptr
205 // must be cleared before use
206 index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens used for ws/LDS calculation
207 index_t unit_size; // this is the M_a of fused-moe kernel
210#if MOE_SORTING_FMOE_2D_BUF
211 // NOTE:
212 // moe_buf_* is a 2d ws buffer used for the following fmoe kernel
213 // arranged as row*col, where row=tokens(or local_token), col=interm_dim
214 // we fuse this clearing inside sorting kernel
215 // Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
216 index_t moe_buf_interm_dim; // p_moe_buf interm_dim
217 index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
218#else
219 long_index_t moe_buf_bytes; // byte size of p_moe_buf
220#endif
221
222};
223
224template <typename Problem_>
226{
228
229 using IndexType = typename Problem::IndexType;
230 using WeightType = typename Problem::WeightType;
231
233
235
236 static constexpr index_t kBlockSize = 256;
237 static constexpr index_t OCCUPANCY = 2; // hard coded
238
// Kernel-side argument pack; MakeKargs() fills it from the host-side Hargs.
// NOTE(review): several members assigned in MakeKargs() (e.g. p_moe_buf, tokens, smem_rows,
// the mdiv helpers) are not visible in this listing — confirm against the full header.
239 struct Kargs
240 {
// untyped input pointers; the element types are decided where they are consumed
241 const void* p_topk_ids;
242 const void* p_weights;
244 const void* p_local_tokens; // [1] if not nullptr, tokens read from here
// description of the fused-moe buffer; the layout depends on MOE_SORTING_FMOE_2D_BUF
252#if MOE_SORTING_FMOE_2D_BUF
253 index_t moe_buf_interm_dim; // p_moe_buf interm_dim
254 index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
255#else
257#endif
263 // mdiv sub_tokens_mdiv;
264 };
265
266 CK_TILE_HOST static constexpr auto get_num_cu()
267 {
268 index_t num_cu = [&]() {
269 hipDeviceProp_t dev_prop;
270 hipDevice_t dev;
271 HIP_CHECK_ERROR(hipGetDevice(&dev));
272 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
273 return dev_prop.multiProcessorCount;
274 }();
275 return num_cu;
276 }
277
// Host-side launch grid size for this kernel.
// The deduced return type differs per configuration: an integer count in 2D-buffer mode,
// a dim3 otherwise.
278 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
279 {
280#if MOE_SORTING_FMOE_2D_BUF
281 (void)h;
// 2D-buffer mode: size the grid from the device, OCCUPANCY blocks per compute unit
282 return get_num_cu() * OCCUPANCY;
283#else
284 // TODO: assume num-experts not too much
// 1D mode: one extra block plus enough blocks so each thread covers 16 bytes of moe_buf_bytes
285 return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, BlockSize(h).x * 16));
286#endif
287 }
288
// Host-side launch block size.
// With MOE_SORTING_USE_EX_KERNEL enabled this is a fixed 256 threads, independent of h.
// NOTE(review): the body of the #else branch is not visible in this listing.
289 CK_TILE_HOST static constexpr auto BlockSize(const Hargs& h)
290 {
291#if MOE_SORTING_USE_EX_KERNEL
292 (void)h;
293 return dim3(256);
294#else
296#endif
297 }
298
299 // in byte
// Dynamic LDS size, in bytes, required for a launch described by h.
300 CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
301 {
302#if MOE_SORTING_USE_EX_KERNEL
// ex kernel: rows x cols index_t elements, as computed by moe_sorting_get_smem_row_col()
303 auto [smem_rows, smem_cols] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
304 return smem_rows * smem_cols * sizeof(index_t);
305#else
306 const auto blocks = BlockSize(h);
307 // usually num_experts is power of 2, we pad 1 dword here for the row-size
// (blocks.x + 1) counter rows of (num_experts + 1) columns, plus one cumsum row
308 return ((blocks.x + 1) * (h.num_experts + 1) + (h.num_experts + 1)) * sizeof(index_t);
309#endif
310 }
311
// Builds the kernel argument pack from the host arguments: copies pointers and sizes,
// precomputes the mdiv (magic-division) helpers, and records the LDS row count.
// NOTE(review): several assignment lines are missing from this listing, so the set of
// fields documented here is incomplete — confirm against the full header.
312 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
313 {
314 Kargs k;
316 k.p_weights = h.p_weights;
322 k.p_moe_buf = h.p_moe_buf;
324 k.tokens = h.tokens;
326#if MOE_SORTING_FMOE_2D_BUF
327 k.moe_buf_interm_dim = h.moe_buf_interm_dim;
328 k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
329#else
331#endif
332
// NOTE(review): `blocks` has no visible use below; presumably consumed by a line missing from this listing
333 const auto blocks = BlockSize(h);
334 // NOTE: tokens could from p_local_tokens, so here this variable is useless
335 // hence moe_align_block_size_kernel() will not behavior properly if we have dynamic tokens
336 // (indeed we can deprecate moe_align_block_size_kernel)
// divisors wrapped in mdiv so the kernel can call div()/divmod() on them
338 k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
339 k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
340 // NOTE: tokens could from p_local_tokens, so here the LDS will be bigger than expected (but works)
// only the row count of the LDS shape is stored; the kernel recomputes cols as num_experts + 1
341 k.smem_rows = [&](){
342 auto [r_, c_] = moe_sorting_get_smem_row_col(h.tokens, h.num_experts);
343 (void) c_;
344 return r_;
345 }();
346 k.expert_mdiv = mdiv{static_cast<uint32_t>(h.num_experts)};
347 // k.sub_tokens_mdiv = mdiv{static_cast<uint32_t>(k.smem_rows - 1)};
348 return k;
349 }
350
351 // [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
352 // NOTE: wave_size need at least be 16!! dpp 16 is one row
// In-place inclusive prefix sum of thread_data across groups of `wave_size` consecutive lanes.
// data_t is bit-cast to and from int for the cross-lane moves, so it must be 32 bits wide.
// Steps 1/2/4/8 use DPP row shifts within a 16-lane row; the 16- and 32-lane steps use
// ds_bpermute to fetch the last lane of a preceding row.
353 template <typename data_t, int wave_size>
354 __device__ inline void wave_cumsum(data_t& thread_data) const
355 {
356 // wave_size must be power of 2
357 constexpr int row_mask = 0xf;
358 constexpr int bank_mask = 0xf;
359 constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
360 auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
361
// step 1: add the value held one lane to the left (zero when shifted in from outside the row)
362 if constexpr(wave_size > 1)
363 {
364 thread_data = reduce_op(
365 thread_data,
366 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
367 0x111,
368 row_mask,
369 bank_mask,
370 bound_ctrl))); // row_shr:1
371 }
372
// step 2: add the partial sum held two lanes to the left
373 if constexpr(wave_size > 2)
374 {
375 thread_data = reduce_op(
376 thread_data,
377 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
378 0x112,
379 row_mask,
380 bank_mask,
381 bound_ctrl))); // row_shr:2
382 }
// step 4: add the partial sum held four lanes to the left
383 if constexpr(wave_size > 4)
384 {
385 thread_data =
386 reduce_op(thread_data,
387 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
388 0x114,
389 row_mask,
390 bank_mask,
391 bound_ctrl))); // row_shr:4
392 }
// 8-lane groups: the shifts above work on 16-lane rows, so lanes 8..15 of a row have also
// accumulated values from lanes 0..7. After one more shift by 8 they hold the full row
// prefix; subtracting lane 7's value (the total of lanes 0..7) leaves the per-group prefix.
393 if constexpr(wave_size == 8) {
394
395 // wave-size=8 need one extra shift
396 thread_data =
397 reduce_op(thread_data,
398 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
399 0x118,
400 row_mask,
401 bank_mask,
402 bound_ctrl))); // row_shr:8
403#if CK_TILE_HAS_ROW_NEWBCAST
404 data_t xxx =__builtin_bit_cast(data_t,
405 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
406 0x157,
407 row_mask,
408 bank_mask,
409 bound_ctrl)); // row_newbcast:7
410
// only the upper 8-lane group of each row subtracts the lower group's total
411 data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
412 thread_data = thread_data - yyy;
413#else
414 // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute
415 // For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group
416 int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
417 int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
418 int bcast7 = __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
419
420 // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
421 if ((__lane_id() / 8) % 2 != 0) { // Note: != 0, not == 0
422 thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
423 }
424#endif
425
426 }
// step 8 for groups wider than 8 lanes: completes the prefix within each 16-lane row
427 if constexpr(wave_size > 8)
428 {
429 thread_data =
430 reduce_op(thread_data,
431 __builtin_bit_cast(data_t, __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
432 0x118,
433 row_mask,
434 bank_mask,
435 bound_ctrl))); // row_shr:8
436 }
437
// step 16: each lane in rows 1..3 adds the last lane of the row just before it
// (lanes 0..15 compute an unused address and have the fetched value forced to 0)
438 if constexpr(wave_size > 16)
439 {
440 // now row-0, row-0+row-1, row-1+row-2, row-2+row-3
441 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2, __builtin_bit_cast(int, thread_data));
442 v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
443 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
444 }
445
// step 32: lanes 32..63 add the value of lane 15 (rows 2) or lane 31 (row 3), which after
// the previous step carries the sum still missing from their running total
446 if constexpr(wave_size > 32)
447 {
448 // lane-id 48...63->31
449 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2, __builtin_bit_cast(int, thread_data));
450 v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
451 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
452 }
453 }
454
455 // reduce single pixel within a wave
456 template <typename T, typename F, index_t wave_size_ = get_warp_size()>
457 __device__ static constexpr T wave_reduce(T local, F reduce_f, number<wave_size_> = {})
458 {
459 // constexpr int wave_size = 64;
460 // constexpr int reduce_stage = 6; // 1<<6=64
461 // clang-format off
462 constexpr int reduce_stage = [](){
463 if constexpr(wave_size_ == 2) return 1;
464 else if constexpr(wave_size_ == 4) return 2;
465 else if constexpr(wave_size_ == 8) return 3;
466 else if constexpr(wave_size_ == 16) return 4;
467 else if constexpr(wave_size_ == 32) return 5;
468 else if constexpr(wave_size_ == 64) return 6;
469 else return 0;
470 }();
471 // clang-format on
472 T v_local = local;
473#pragma unroll reduce_stage
474 for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
475 {
476 int src_lane = __lane_id() ^ (1 << i_stage);
477 int32_t v_remote_tmp =
478 __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
479 T v_remote = bit_cast<T>(v_remote_tmp);
480 v_local = reduce_f(v_local, v_remote);
481 }
482 return v_local;
483 }
484
486 {
487 return row * total_col + col;
488 }
489
491 {
492 const index_t offset = (blockIdx.x - 1) * blockDim.x + threadIdx.x;
493 if(offset < buf_bytes / 16)
494 {
495 buf[offset] = uint8x16_t{0};
496 }
497 }
498
499 CK_TILE_DEVICE void
500 moe_buf_set_zero_kernel_2d(void* buf, index_t row, index_t col, index_t elem_bytes) const
501 {
502 const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
503 const long_index_t total_bytes = total_pixels * elem_bytes;
504 const long_index_t total_elems = total_bytes / 16; // always use dwordx4
505
506 using vector_type = ext_vector_t<index_t, 4>;
507 vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
508 auto zero_ = vector_type{0};
509
510 for(long_index_t i = (blockIdx.x - 1) * kBlockSize + threadIdx.x; i < total_elems;
511 i += (gridDim.x - 1) * kBlockSize)
512 {
513 p_buf[i] = zero_;
514 }
515 }
516
// Sorts (token, topk) entries by expert id inside one thread block, using LDS counters.
//
// Phases, separated by __syncthreads():
//   1. every thread counts, per expert, the entries in its own slice of topk_id
//      (tokens_per_thread consecutive entries starting at tid * tokens_per_thread);
//   2. the counts are prefix-summed across threads for each expert, so row blockDim.x of
//      tokens_cnts holds the total per expert;
//   3. each expert's total is rounded up to a multiple of unit_size (at least one unit) and
//      accumulated into cumsum; the grand total is written to *p_total_tokens_post_pad;
//   4. p_sorted_expert_ids receives one expert id per unit_size-sized tile;
//   5. every thread scatters its entries (token id, weight) to their sorted positions;
//   6. the padding slots of each expert are filled with token id numel / topk and weight 0.
//
// smem must hold (blockDim.x + 1) * (num_experts + 1) + (num_experts + 1) index_t elements
// (see the pointer arithmetic on `cumsum` below).
517 CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType* __restrict__ topk_id,
518 const WeightType* __restrict__ weights,
519 index_t* p_sorted_token_ids,
520 WeightType* p_sorted_weights,
521 index_t* p_sorted_expert_ids,
522 index_t* p_total_tokens_post_pad,
523 const index_t num_experts,
524 const index_t tokens_per_thread,
525 const index_t numel,
526 const mdiv unit_size_mdiv,
527 const mdiv topk_mdiv,
528 void* smem) const
529 {
530 const index_t tid = static_cast<index_t>(threadIdx.x);
531 const index_t start_idx = tid * tokens_per_thread;
532
533 index_t* shared_mem = reinterpret_cast<index_t*>(smem);
534
535 index_t* tokens_cnts = shared_mem; // 2d: (blockDim.x + 1, num_experts)
536 index_t* cumsum = shared_mem + (blockDim.x + 1) * (num_experts + 1); // 1: (num_experts + 1)
537
// phase 1: clear this thread's counter row (row tid + 1; row 0 is the zero base row)
538 for(int i = 0; i < num_experts; ++i)
539 {
540 tokens_cnts[calc_index(num_experts + 1, tid + 1, i)] = 0;
541 }
542
// phase 1: count this thread's entries per expert
543#pragma unroll Problem_::InternalLoadUnroll
544 for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
545 {
546 ++tokens_cnts[calc_index(num_experts + 1, tid + 1, topk_id[i])];
547 }
548 __syncthreads();
549
// phase 2: prefix-sum the counts over the thread rows, one expert column per thread
550#if MOE_SORTING_FUSE_MP_01
551 if(tid < num_experts)
552 {
553 tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
554 index_t local_c[8];
555 index_t prev_c = 0;
556 // TODO: manually unroll. pragma unroll does not work well when we have dependency
// rows are processed 8 at a time; prev_c carries the running total between groups
557 for(int i = 1; i <= static_cast<index_t>(blockDim.x); i += 8)
558 {
559 local_c[0] = tokens_cnts[calc_index(num_experts + 1, i + 0, tid)];
560 local_c[1] = tokens_cnts[calc_index(num_experts + 1, i + 1, tid)];
561 local_c[2] = tokens_cnts[calc_index(num_experts + 1, i + 2, tid)];
562 local_c[3] = tokens_cnts[calc_index(num_experts + 1, i + 3, tid)];
563 local_c[4] = tokens_cnts[calc_index(num_experts + 1, i + 4, tid)];
564 local_c[5] = tokens_cnts[calc_index(num_experts + 1, i + 5, tid)];
565 local_c[6] = tokens_cnts[calc_index(num_experts + 1, i + 6, tid)];
566 local_c[7] = tokens_cnts[calc_index(num_experts + 1, i + 7, tid)];
567
568 local_c[0] += prev_c;
569 local_c[1] += local_c[0];
570 local_c[2] += local_c[1];
571 local_c[3] += local_c[2];
572 local_c[4] += local_c[3];
573 local_c[5] += local_c[4];
574 local_c[6] += local_c[5];
575 local_c[7] += local_c[6];
576 prev_c = local_c[7];
577
578 tokens_cnts[calc_index(num_experts + 1, i + 0, tid)] = local_c[0];
579 tokens_cnts[calc_index(num_experts + 1, i + 1, tid)] = local_c[1];
580 tokens_cnts[calc_index(num_experts + 1, i + 2, tid)] = local_c[2];
581 tokens_cnts[calc_index(num_experts + 1, i + 3, tid)] = local_c[3];
582 tokens_cnts[calc_index(num_experts + 1, i + 4, tid)] = local_c[4];
583 tokens_cnts[calc_index(num_experts + 1, i + 5, tid)] = local_c[5];
584 tokens_cnts[calc_index(num_experts + 1, i + 6, tid)] = local_c[6];
585 tokens_cnts[calc_index(num_experts + 1, i + 7, tid)] = local_c[7];
586 }
587 }
588#else
589 // TODO: below code still working, but slow in expert=32/topk=5 case. Put here for future
590 // heuristic
591 {
592 if(tid < num_experts)
593 tokens_cnts[calc_index(num_experts + 1, 0, tid)] = 0;
594 for(int i = 0; i < num_experts; i += 8)
595 {
596 index_t local_c[8];
597#pragma unroll
598 for(int j = 0; j < 8; j++)
599 {
600 local_c[j] = tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)];
601 }
602
603#pragma unroll
604 for(int j = 0; j < 8; j++)
605 {
606 wave_cumsum<int, 64>(local_c[j]);
607 }
608
609#pragma unroll
610 for(int j = 0; j < 8; j++)
611 {
612 tokens_cnts[calc_index(num_experts + 1, tid + 1, i + j)] = local_c[j];
613 }
614 }
615 }
616#endif
617
618 __syncthreads();
// phase 3: padded per-expert sizes -> cumsum[]; serial on thread 0 when ExpertTile == 0,
// otherwise one expert per lane with a wave-level prefix sum
619 if constexpr(Problem::ExpertTile == 0)
620 {
621 if(tid == 0)
622 {
623 cumsum[0] = 0;
624 for(int i = 1; i <= num_experts; ++i)
625 {
626 auto current_units = [&]() {
// ceil(count / unit_size), but never fewer than one unit per expert
627 index_t x_ = tokens_cnts[calc_index(num_experts + 1, blockDim.x, i - 1)] +
628 unit_size_mdiv.divisor - 1;
629 index_t y_ = unit_size_mdiv.div(x_);
630 return max(y_, 1) * unit_size_mdiv.divisor;
631 }();
632 cumsum[i] = cumsum[i - 1] + current_units;
633 }
634 *p_total_tokens_post_pad = cumsum[num_experts];
635 }
636 }
637 else
638 {
639 // TODO: we have out-of-bound read here. But result is still OK (will ignore tid >=
640 // expert) for simplicity, not check experts here.
641 int local_cnt = tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
642 int blocks_pers_expert = unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
643 int padded_tokens_per_expert = max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
644 int local_cumsum = padded_tokens_per_expert;
645 wave_cumsum<int, 64>(local_cumsum);
646
// the lane of the last expert holds the grand total after the inclusive prefix sum
647 if(tid == (num_experts - 1))
648 {
649 cumsum[0] = 0;
650 *p_total_tokens_post_pad = local_cumsum;
651 }
652 if(tid < num_experts)
653 {
654 cumsum[tid + 1] = local_cumsum;
655 }
656 }
657
658 __syncthreads();
// phase 4: one expert-id entry per unit_size tile in the expert's padded range
659 if(tid < num_experts)
660 {
661 int e_start = cumsum[tid];
662 int e_end = cumsum[tid + 1];
663 for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
664 {
665 p_sorted_expert_ids[unit_size_mdiv.div(i)] = tid;
666 }
667 }
668
// phase 5: scatter. Row `tid` holds the prefix of all lower-numbered threads, i.e. this
// thread's starting offset inside each expert; it is bumped after every placed entry.
669#pragma unroll Problem_::InternalLoadUnroll
670 for(int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i)
671 {
672 index_t expert_id = topk_id[i];
673 index_t local_cnt = tokens_cnts[calc_index(num_experts + 1, tid, expert_id)];
674 index_t rank_post_pad = local_cnt + cumsum[expert_id];
675#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
676 uint32_t curr_token_id, curr_topk_id;
677 topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
678 p_sorted_token_ids[rank_post_pad] = MOE_SORTING_MOCK_ID(curr_token_id, curr_topk_id);
679#else
680 p_sorted_token_ids[rank_post_pad] = topk_mdiv.div(i);
681#endif
682 p_sorted_weights[rank_post_pad] = weights[i];
683 tokens_cnts[calc_index(num_experts + 1, tid, expert_id)] = local_cnt + 1;
684 }
685
// phase 6: fill padding slots with token id numel / topk and zero weight; one thread per
// expert when ExpertTile == 0, otherwise several threads share one expert
686 if constexpr(Problem::ExpertTile == 0)
687 {
688 const index_t prefill_token = topk_mdiv.div(numel);
689 if(tid < num_experts)
690 {
// first padding slot = expert start + number of real entries (row blockDim.x holds the totals)
691 index_t expert_offset =
692 cumsum[tid] + tokens_cnts[calc_index(num_experts + 1, blockDim.x, tid)];
693 index_t expert_end = cumsum[tid + 1];
694 while(expert_offset < expert_end)
695 {
696#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
697 p_sorted_token_ids[expert_offset] =
698 MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
699#else
700 p_sorted_token_ids[expert_offset] = prefill_token;
701#endif
702 p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
703 expert_offset++;
704 }
705 }
706 }
707 else
708 {
709 const index_t prefill_token = topk_mdiv.div(numel);
710 // TODO: only support expert-tile like 8, 16, 32
711 static constexpr index_t experts_per_wave = get_warp_size() / Problem::ExpertTile;
712 {
// experts_per_wave consecutive threads share one expert and stride through its padding
713 index_t eid = tid / experts_per_wave;
714 index_t expert_offset = cumsum[eid] +
715 tokens_cnts[calc_index(num_experts + 1, blockDim.x, eid)] +
716 tid % experts_per_wave;
717 index_t expert_end = cumsum[eid + 1];
718 if(eid < num_experts)
719 {
720 while(expert_offset < expert_end)
721 {
722#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
723 p_sorted_token_ids[expert_offset] =
724 MOE_SORTING_MOCK_ID(prefill_token, topk_mdiv.divisor);
725#else
726 p_sorted_token_ids[expert_offset] = prefill_token;
727#endif
728 p_sorted_weights[expert_offset] = static_cast<WeightType>(0.0);
729 expert_offset += experts_per_wave;
730 }
731 }
732 }
733 }
734 }
735
736 // only support index_t, and single pixel access
738 {
741
742 // this is 2D
744 : smem(smem_), row_stride(row_stride_)
745 {
746 }
748 {
749 return smem[i_row * row_stride + i_col];
750 }
752 {
753 return smem[i_row * row_stride + i_col];
754 }
755
756 // this is 1D or linear
758 CK_TILE_DEVICE const index_t& operator()(index_t idx) const { return smem[idx]; }
760 };
761
762 CK_TILE_DEVICE void
763 moe_align_block_size_kernel_ex(const IndexType* __restrict__ topk_id,
764 const WeightType* __restrict__ weights,
765 const IndexType* __restrict__ local_expert_mask,
766 index_t* p_sorted_token_ids,
767 WeightType* p_sorted_weights,
768 index_t* p_sorted_expert_ids,
769 index_t* p_total_tokens_post_pad,
770 const index_t num_experts,
771 const index_t tokens,
772 const mdiv unit_size_mdiv,
773 const mdiv topk_mdiv,
774 const mdiv expert_mdiv,
775 const index_t smem_rows,
776 void* smem) const
777 {
778 const index_t tid = static_cast<index_t>(threadIdx.x);
780 const index_t lid = __lane_id();
781 constexpr index_t block_size = 256; // blockDim.x;
782 const index_t sub_tokens = smem_rows - 2; // sub_tokens_mdiv.divisor;
783 const index_t topk = topk_mdiv.divisor;
784 auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
785
786 const index_t smem_cols = num_experts + 1;
787
788 simple_smem_indexer smem_cumsum{reinterpret_cast<index_t*>(smem) + 0};
789 simple_smem_indexer smem_cumdup{reinterpret_cast<index_t*>(smem) + smem_cols};
790 simple_smem_indexer smem_tokens{reinterpret_cast<index_t*>(smem) + 2 * smem_cols,
791 smem_cols};
792
793 // #pragma unroll 8
794 for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
795 {
796 uint32_t curr_token_id, curr_expert_id;
797 expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
798 smem_tokens(curr_token_id, curr_expert_id) = 0;
799 }
800 __syncthreads();
801
802 for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
803 {
804 // NOTE: below for loop can't have barrier inside!!
805 for(int i = tid; i < (sub_tokens * topk); i += block_size)
806 {
807 uint32_t curr_token_id, curr_topk_id;
808 topk_mdiv.divmod(i, curr_token_id, curr_topk_id);
809 int i_t = i_token + curr_token_id;
810
811 if(i_t < tokens)
812 {
813 int eid = topk_id[i_t * topk + curr_topk_id];
814
815 if constexpr(Problem::SubTokenOneShot)
816 smem_tokens(curr_token_id, eid) = curr_topk_id + 1;
817 else
818 smem_tokens(curr_token_id, eid)++;
819 }
821 }
822 __syncthreads(); // make sure different i_token iteration not overlap by different wave
823 }
824
825 // counting
826 if(tid == 0)
827 {
828 smem_cumsum(0) = 0;
829 // smem_cumdup(0) = 0;
830 }
831
832 {
833 constexpr int lane_group_sz = 8;
834 int lane_group_id = tid / lane_group_sz;
835 int lane_group_os = tid % lane_group_sz;
836 constexpr int lane_group_nm = block_size / lane_group_sz;
837
838 for(int i_e = lane_group_id; i_e < num_experts; i_e += lane_group_nm)
839 {
840 index_t local_c[Problem::SubTokenTile];
841 index_t cnt = 0;
842
843 for(int i = 0; i < sub_tokens; i += 8 * Problem::SubTokenTile)
844 {
845#pragma unroll Problem::SubTokenTile
846 for(int j = 0; j < Problem::SubTokenTile; j++)
847 {
848 local_c[j] = smem_tokens(i + j * 8 + lane_group_os, i_e);
849 if constexpr(Problem::SubTokenOneShot)
850 {
851 local_c[j] = local_c[j] != 0 ? 1 : 0;
852 }
853 }
854
855#pragma unroll Problem::SubTokenTile
856 for(int j = 0; j < Problem::SubTokenTile; j++)
857 {
858 cnt += wave_reduce(local_c[j], f_sum, number<8>{});
859 }
860 }
861 if(lane_group_os == 0)
862 smem_cumsum(i_e + 1) = cnt;
863 }
864 }
865
866 if constexpr(Problem::LocalExpertMasking)
867 {
868 smem_cumdup(0) = 0;
869 for(int i_e = tid; i_e < num_experts; i_e += block_size)
870 {
871 // reuse this buffer
872 smem_cumdup(i_e + 1) = local_expert_mask[i_e];
873 }
874 }
875
876 __syncthreads();
877
878 {
879 if(wid == 0)
880 {
881 // NOTE: under this block can never use __syncthreads!
882 int i_e_ = 0;
883 int local_cumsum_ = 0;
884 for(; i_e_ < num_experts; i_e_ += get_warp_size())
885 {
886 int pre_cumsum_ = smem_cumsum(lid == 0 ? i_e_ : 0);
887 int local_cnt = smem_cumsum(i_e_ + lid + 1);
888 int blocks_pers_expert =
889 unit_size_mdiv.div(local_cnt + unit_size_mdiv.divisor - 1);
890
891 int pre_cumsum_masking = [&]() {
892 if constexpr(Problem::LocalExpertMasking)
893 return smem_cumdup(lid == 0 ? i_e_ : 0);
894 else
895 return 0; // not used
896 }();
897 int local_masking = [&]() {
898 if constexpr(Problem::LocalExpertMasking)
899 return smem_cumdup(i_e_ + lid + 1);
900 else
901 return 0; // not used
902 }();
903 int padded_tokens_per_expert = [&]() {
904 int x_ = [&]() {
905 if constexpr(Problem::SkipExpertsWithZeroTokens)
906 {
907 // if local_cnt is zero, blocks_pers_expert will be zero
908 // this is what we want to achieve
909 return blocks_pers_expert * unit_size_mdiv.divisor;
910 }
911 else
912 {
913 return max(blocks_pers_expert, 1) * unit_size_mdiv.divisor;
914 }
915 }();
916 if constexpr(Problem::LocalExpertMasking)
917 {
918 return local_masking ? x_ : 0;
919 }
920 else
921 return x_;
922 }();
923
924 local_cumsum_ = padded_tokens_per_expert;
925 local_cumsum_ += pre_cumsum_; // note pre_cumsum must be added after local
926 // cumsum padded in case local cumsum is zero, but
927 // pre_sumsum has value, which will result int
928 // zero local cumsum(but we want at least padded)
930
931 if((i_e_ + lid) < num_experts)
932 smem_cumsum(i_e_ + lid + 1) = local_cumsum_;
933
934 if constexpr(Problem::LocalExpertMasking)
935 {
936 local_masking += pre_cumsum_masking;
938 if((i_e_ + lid) < num_experts)
939 smem_cumdup(i_e_ + lid + 1) = local_masking;
940 }
941
942 // NOTE: this waitcnt is a must, compiler will not generate waitcnt lgkmcnt()
943 // for above write however __syncthreads will cause barrier with waves other
944 // than 0(which is not we want)
946 }
947 if((lid + i_e_ - get_warp_size()) == (num_experts - 1))
948 {
949 *p_total_tokens_post_pad = local_cumsum_;
950 p_total_tokens_post_pad[1] = tokens;
951 }
952 }
953 __syncthreads();
954 }
955
956 for(int i_e = tid; i_e < num_experts; i_e += block_size)
957 {
958 int e_start = smem_cumsum(i_e);
959 int e_end = smem_cumsum(i_e + 1);
960
961 int expert_id = [&]() {
962 if constexpr(Problem::LocalExpertMasking)
963 {
964 // local expert id from cumsum
965 return smem_cumdup(i_e);
966 }
967 else
968 return i_e;
969 }();
970
971 smem_cumdup(i_e) = e_start; // duplicate cumsum for later use
972 if constexpr(Problem::SkipExpertsWithZeroTokens)
973 {
974 if(e_start == e_end) // skip zero token expert
975 continue;
976 }
977
978 if constexpr(Problem::LocalExpertMasking)
979 {
980 if(local_expert_mask[i_e] == 0)
981 continue;
982 }
983
984 for(int i = e_start; i < e_end; i += unit_size_mdiv.divisor)
985 {
986 p_sorted_expert_ids[unit_size_mdiv.div(i)] = expert_id;
987 }
988 }
989 smem_cumdup(num_experts) = smem_cumsum(num_experts);
990
991 // fill the p_sorted_token_ids/p_sorted_weights
992 for(int i_token = 0; i_token < tokens; i_token += sub_tokens)
993 {
994 if constexpr(!Problem::SubTokenOneShot)
995 {
996 // clear every time
997 for(int i = tid; i < (sub_tokens * num_experts); i += block_size)
998 {
999 uint32_t curr_token_id, curr_expert_id;
1000 expert_mdiv.divmod(i, curr_token_id, curr_expert_id);
1001 smem_tokens(curr_token_id, curr_expert_id) = 0;
1002 }
1003 __syncthreads();
1004
1005 // load again
1006 for(int i = tid; i < (sub_tokens * topk); i += block_size)
1007 {
1008 uint32_t curr_token_id_, curr_topk_id_;
1009 topk_mdiv.divmod(i, curr_token_id_, curr_topk_id_);
1010 int curr_token_id = static_cast<int>(curr_token_id_);
1011 int curr_topk_id = static_cast<int>(curr_topk_id_);
1012 int i_t = i_token + curr_token_id;
1013 if(i_t < tokens)
1014 {
1015 int eid = topk_id[i_t * topk + curr_topk_id];
1016 smem_tokens(curr_token_id, eid) = curr_topk_id + 1; // at least 1
1017 }
1018 }
1019 __syncthreads();
1020 }
1021
1022 {
1023 constexpr int lane_group_sz = 8;
1024 int lane_group_id = tid / lane_group_sz;
1025 int lane_group_os = tid % lane_group_sz;
1026 constexpr int lane_group_nm = block_size / lane_group_sz;
1027 for(int eid = lane_group_id; eid < num_experts; eid += lane_group_nm)
1028 {
1029 if constexpr(Problem::LocalExpertMasking)
1030 {
1031 if(local_expert_mask[eid] == 0)
1032 continue;
1033 }
1034 int position = smem_cumsum(eid);
1035 for(int i_sub_token = lane_group_os; i_sub_token < sub_tokens;
1036 i_sub_token += lane_group_sz)
1037 {
1038 auto x = smem_tokens(i_sub_token, eid);
1039
1040 int local_cnt_cache = x != 0 ? 1 : 0;
1041 int local_cnt = local_cnt_cache;
1043 if(x != 0)
1044 {
1045 // now x is topk value
1046#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
1047 p_sorted_token_ids[position + local_cnt - 1] =
1048 MOE_SORTING_MOCK_ID(i_token + i_sub_token, x - 1);
1049#else
1050 p_sorted_token_ids[position + local_cnt - 1] = i_token + i_sub_token;
1051#endif
1052 p_sorted_weights[position + local_cnt - 1] =
1053 weights[(i_token + i_sub_token) * topk + x - 1];
1054 }
1055
1056 int remote_cnt = __builtin_amdgcn_ds_bpermute(
1057 (lane_group_sz * (lane_group_id + 1) - 1) << 2, local_cnt);
1058
1059 position += remote_cnt;
1060 }
1061 smem_cumsum(eid) = position;
1062 }
1063 }
1064 __syncthreads();
1065 }
1066
1067 // add the skip number
1068 for(int eid = tid; eid < num_experts; eid += block_size)
1069 {
1070 int e_start = smem_cumsum(eid);
1071 int e_end = smem_cumdup(eid + 1);
1072 if constexpr(Problem::SkipExpertsWithZeroTokens)
1073 {
1074 if(e_start == e_end) // skip zero token expert
1075 continue;
1076 }
1077 while(e_start < e_end)
1078 {
1079#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
1080 p_sorted_token_ids[e_start] = MOE_SORTING_MOCK_ID(tokens, topk);
1081#else
1082 p_sorted_token_ids[e_start] = tokens;
1083#endif
1084 p_sorted_weights[e_start] = static_cast<WeightType>(0.0);
1085 e_start++;
1086 }
1087 }
1088 }
1089
1091 {
1092 index_t tokens_ = [&]() {
1093 if constexpr(Problem::LocalToken)
1094 {
1095 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
1096 }
1097 else
1098 {
1099 return kargs.tokens;
1100 }
1101 }();
1102
1103 if(blockIdx.x > 0)
1104 {
1105 if(kargs.p_moe_buf)
1106 {
1107#if MOE_SORTING_FMOE_2D_BUF
1109 kargs.p_moe_buf, tokens_, kargs.moe_buf_interm_dim, kargs.moe_buf_elem_bytes);
1110#else
1111 moe_buf_set_zero_kernel(reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
1112 kargs.moe_buf_bytes);
1113#endif
1114 }
1115 return;
1116 }
1117
1118 extern __shared__ char smem[];
1119
1120#if MOE_SORTING_USE_EX_KERNEL
1122 static_cast<const IndexType*>(kargs.p_topk_ids),
1123 static_cast<const WeightType*>(kargs.p_weights),
1124 static_cast<const IndexType*>(kargs.p_local_expert_mask),
1125 static_cast<IndexType*>(kargs.p_sorted_token_ids),
1126 static_cast<WeightType*>(kargs.p_sorted_weights),
1127 static_cast<IndexType*>(kargs.p_sorted_expert_ids),
1128 static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
1129 kargs.num_experts,
1130 tokens_,
1131 kargs.unit_size_mdiv,
1132 kargs.topk_mdiv,
1133 kargs.expert_mdiv,
1134 kargs.smem_rows,
1135 smem);
1136#else
1137 const size_t numel = kargs.tokens * kargs.topk_mdiv.divisor;
1138 return moe_align_block_size_kernel(static_cast<const IndexType*>(kargs.p_topk_ids),
1139 static_cast<const WeightType*>(kargs.p_weights),
1140 static_cast<IndexType*>(kargs.p_sorted_token_ids),
1141 static_cast<WeightType*>(kargs.p_sorted_weights),
1142 static_cast<IndexType*>(kargs.p_sorted_expert_ids),
1143 static_cast<IndexType*>(kargs.p_total_tokens_post_pad),
1144 kargs.num_experts,
1145 kargs.tokens_per_thread,
1146 numel,
1147 kargs.unit_size_mdiv,
1148 kargs.topk_mdiv,
1149 smem);
1150#endif
1151 }
1152};
1153
1154namespace impl {
1155
1156// Row stride (in elements) of the expert mesh, laid out as [expert, padded_tokens]:
// the token count rounded up to the next multiple of 32.
1158{
1159 // Pad to a multiple of 32. This makes sure that even if the mesh is 8-bit,
1160 // we can still use dwordx4 load/store
1161 constexpr index_t chunk = 32;
1162 return (tokens + chunk - 1) / chunk * chunk;
1163};
1164
1165// Bytes per mesh element: 4 - i32 mesh, 2 - i16 mesh, 1 - i8 mesh
1167 index_t /*num_experts_*/,
1168 index_t topk_)
1169{
1170 // small token case (fewer than 512 tokens): run the mesh with a dword score board
1171 if(tokens_ < 512)
1172 return 4;
1173 else
1174 {
1175 if(topk_ >= 255)
1176 return 2; // 16bit mesh
1177 else
1178 return 1; // 8bit mesh if small enough
1179 }
1180}
1181
1183 index_t num_experts,
1184 index_t topk)
1185{
 // Size in bytes of the whole expert mesh: num_experts rows, each
 // moe_sorting_mp_mesh_stride(tokens) elements wide, multiplied by the
 // per-element byte size returned by moe_sorting_mesh_byte_size.
1186 index_t row_size = moe_sorting_mp_mesh_stride(tokens);
1187 index_t elem = num_experts * row_size;
1188 return elem * moe_sorting_mesh_byte_size(tokens, num_experts, topk);
1189};
1190
1192{
 // (num_experts + 1) index_t entries, rounded up to a multiple of 32 entries;
 // the returned value is a size in bytes.
1193 constexpr index_t chunk = 32;
1194 index_t row_size = num_experts + 1;
1195 return (row_size + chunk - 1) / chunk * chunk * sizeof(index_t);
1196};
1197
1199{
1200 constexpr index_t chunk = 32;
1201 return chunk * sizeof(index_t);
1202};
1203
1204template <typename T, typename F, index_t wave_size_ = get_warp_size()>
1205CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number<wave_size_> = {})
1206{
1207 // constexpr int wave_size = 64;
1208 // constexpr int reduce_stage = 6; // 1<<6=64
1209 // clang-format off
1210 constexpr int reduce_stage = [](){
1211 if constexpr(wave_size_ == 2) return 1;
1212 else if constexpr(wave_size_ == 4) return 2;
1213 else if constexpr(wave_size_ == 8) return 3;
1214 else if constexpr(wave_size_ == 16) return 4;
1215 else if constexpr(wave_size_ == 32) return 5;
1216 else if constexpr(wave_size_ == 64) return 6;
1217 else return 0;
1218 }();
1219 // clang-format on
1220 T v_local = local;
1221#pragma unroll reduce_stage
1222 for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
1223 {
1224 int src_lane = __lane_id() ^ (1 << i_stage);
1225 int32_t v_remote_tmp =
1226 __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
1227 T v_remote = bit_cast<T>(v_remote_tmp);
1228 v_local = reduce_f(v_local, v_remote);
1229 }
1230 return v_local;
1231}
1232
1233// [a, b, c, d....] -> [a, a+b, a+b+c, a+b+c+d, ....]
1234// NOTE: wave_size need at least be 16!! dpp 16 is one row
1235template <typename data_t, int wave_size>
1237{
1238 // wave_size must be power of 2
1239 constexpr int row_mask = 0xf;
1240 constexpr int bank_mask = 0xf;
1241 constexpr bool bound_ctrl = true; // ! out-of-bound is zero !
1242 auto reduce_op = [&](auto x_, auto y_) { return x_ + y_; };
1243
1244 if constexpr(wave_size > 1)
1245 {
1246 thread_data = reduce_op(
1247 thread_data,
1248 __builtin_bit_cast(data_t,
1249 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
1250 0x111,
1251 row_mask,
1252 bank_mask,
1253 bound_ctrl))); // row_shr:1
1254 }
1255
1256 if constexpr(wave_size > 2)
1257 {
1258 thread_data = reduce_op(
1259 thread_data,
1260 __builtin_bit_cast(data_t,
1261 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
1262 0x112,
1263 row_mask,
1264 bank_mask,
1265 bound_ctrl))); // row_shr:2
1266 }
1267 if constexpr(wave_size > 4)
1268 {
1269 thread_data = reduce_op(
1270 thread_data,
1271 __builtin_bit_cast(data_t,
1272 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
1273 0x114,
1274 row_mask,
1275 bank_mask,
1276 bound_ctrl))); // row_shr:4
1277 }
1278 if constexpr(wave_size == 8)
1279 {
1280
1281 // wave-size=8 need one extra shift
1282 thread_data = reduce_op(
1283 thread_data,
1284 __builtin_bit_cast(data_t,
1285 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
1286 0x118,
1287 row_mask,
1288 bank_mask,
1289 bound_ctrl))); // row_shr:8
1290#if CK_TILE_HAS_ROW_NEWBCAST
1291 data_t xxx =
1292 __builtin_bit_cast(data_t,
1293 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
1294 0x157,
1295 row_mask,
1296 bank_mask,
1297 bound_ctrl)); // row_newbcast:7
1298
1299 data_t yyy = (__lane_id() / 8) % 2 == 0 ? 0 : xxx;
1300 thread_data = thread_data - yyy;
1301#else
1302 // portable fallback for gfx908 and older: emulate row_newbcast:7 via ds_bpermute
1303 // For wave_size == 8 context, we need to broadcast from lane 7 of the 16-lane group
1304 int broadcast_src_lane = (__lane_id() & ~15) + 7; // Lane 7 of the 16-lane group
1305 int broadcast_addr = broadcast_src_lane << 2; // Convert to byte address
1306 int bcast7 =
1307 __builtin_amdgcn_ds_bpermute(broadcast_addr, __builtin_bit_cast(int, thread_data));
1308
1309 // Apply subtraction only to odd 8-lane groups (lanes 8-15 of each 16-lane unit)
1310 if((__lane_id() / 8) % 2 != 0)
1311 { // Note: != 0, not == 0
1312 thread_data = thread_data - __builtin_bit_cast(data_t, bcast7);
1313 }
1314#endif
1315 }
1316 if constexpr(wave_size > 8)
1317 {
1318 thread_data = reduce_op(
1319 thread_data,
1320 __builtin_bit_cast(data_t,
1321 __builtin_amdgcn_mov_dpp(__builtin_bit_cast(int, thread_data),
1322 0x118,
1323 row_mask,
1324 bank_mask,
1325 bound_ctrl))); // row_shr:8
1326 }
1327
1328 if constexpr(wave_size > 16)
1329 {
1330 // now row-0, row-0+row-1, row-1+row-2, row-2+row-3
1331 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 1) << 2,
1332 __builtin_bit_cast(int, thread_data));
1333 v_remote_tmp = __lane_id() >= 16 ? v_remote_tmp : 0;
1334 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
1335 }
1336
1337 if constexpr(wave_size > 32)
1338 {
1339 // lane-id 48...63->31
1340 int v_remote_tmp = __builtin_amdgcn_ds_bpermute(((__lane_id() & 0x30) - 17) << 2,
1341 __builtin_bit_cast(int, thread_data));
1342 v_remote_tmp = __lane_id() >= 32 ? v_remote_tmp : 0;
1343 thread_data = reduce_op(thread_data, __builtin_bit_cast(data_t, v_remote_tmp));
1344 }
1345}
1346
1347template <index_t kBlockSize = 256>
1349{
1350 // const index_t offset = (blockIdx.x - 1) * kBlockSize + threadIdx.x;
1351 long_index_t offset = static_cast<long_index_t>(gid) * kBlockSize + threadIdx.x;
1352 if(offset < buf_bytes / 16)
1353 {
1354 buf[offset] = uint8x16_t{0};
1355 }
1356}
1357
1358template <index_t kBlockSize = 256>
1360 void* buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
1361{
1362 const long_index_t total_pixels = static_cast<long_index_t>(row) * col;
1363 const long_index_t total_bytes = total_pixels * elem_bytes;
1364 const long_index_t total_elems = total_bytes / 16; // always use dwordx4
1365
1366 using vector_type = ext_vector_t<index_t, 4>;
1367 vector_type* p_buf = reinterpret_cast<vector_type*>(buf);
1368 auto zero_ = vector_type{0};
1369
1370 for(long_index_t i = gid * kBlockSize + threadIdx.x; i < total_elems; i += blocks * kBlockSize)
1371 {
1372 p_buf[i] = zero_;
1373 }
1374}
1375
1376} // namespace impl
1377
1378// TODO: tokens could be from
1379// prefer to run mp kernel if is not oneshot
1380CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
1381{
1382#if CK_TILE_WA_ISSUE_2028
1383 if(tokens_ >= 65536 * 2)
1384 {
1385 return true;
1386 }
1387#endif
1388 auto sub_token_ = moe_sorting_get_sub_token(tokens_, num_experts_);
1389 bool is_sub_token_onshot = tokens_ <= sub_token_;
1390 return is_sub_token_onshot;
1391}
1392
1393// return size in byte
1394CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_)
1395{
1396 index_t s_ = impl::moe_sorting_mp_mesh_smem_size(tokens_, num_experts_, topk_) +
1398#if MOE_SORTING_FUSE_MP_01
1400#else
1401 ;
1402#endif
1403 return s_;
1404}
1405
1406// return size in bytes
1407// dispatch_policy: 0-automatically pick the kernel, 1-always use single kernel, 2-always use mp
1408// kernel
1410 int num_experts_,
1411 int topk_,
1412 int dispatch_policy_)
1413{
1414#if 1
1415 // return 0;
1416 if(dispatch_policy_ == 0)
1417 {
1418 if(moe_sorting_is_oneshot(tokens_, num_experts_))
1419 {
1420 return 0;
1421 }
1422 else
1423 {
1424 return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
1425 }
1426 }
1427 else if(dispatch_policy_ == 1)
1428 {
1429 return 0; // always use single kernel
1430 }
1431 else
1432 {
1433 return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
1434 }
1435#else
1436 return moe_sorting_mp_get_workspace_size(tokens_, num_experts_, topk_);
1437#endif
1438}
1439
1440template <typename Problem_>
1442{
1444 static constexpr index_t kBlockSize = Problem::BlockSize;
1445 static constexpr index_t OCCUPANCY = Problem::Occu;
1446
1448
1449 struct Kargs
1450 {
1451 const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
1452 void* p_expert_mesh; // [expert, tokens]
1453 index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
1454 // used for ws/LDS calculation
1456 index_t mesh_stride; // mesh_stride for p_expert_mesh
1458 };
1459
1460 CK_TILE_HOST static constexpr auto get_num_cu()
1461 {
1462 index_t num_cu = [&]() {
1463 hipDeviceProp_t dev_prop;
1464 hipDevice_t dev;
1465 HIP_CHECK_ERROR(hipGetDevice(&dev));
1466 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
1467 return dev_prop.multiProcessorCount;
1468 }();
1469 return num_cu;
1470 }
1471
1472 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
1473 {
1474 Kargs k;
1476 k.p_expert_mesh = h.p_ws;
1477 k.tokens = h.tokens;
1481 return k;
1482 }
1483
1484 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
1485
1486 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
1487
1488 // in byte
1489 CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
1490
1492 {
1493 index_t tokens = [&]() {
1494 if constexpr(Problem::LocalToken)
1495 {
1496 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
1497 }
1498 else
1499 {
1500 return kargs.tokens;
1501 }
1502 }();
1503
1504 index_t mesh_stride = [&]() {
1505 if constexpr(Problem::LocalToken)
1506 {
1507 return impl::moe_sorting_mp_mesh_stride(tokens);
1508 }
1509 else
1510 {
1511 return kargs.mesh_stride;
1512 }
1513 }();
1514
1515 index_t row_size = mesh_stride; // impl::moe_sorting_mp_mesh_stride(tokens);
1516 index_t pixels = kargs.num_experts * row_size;
1517 index_t total_bytes = pixels * kargs.mesh_byte_size;
1518 index_t total_elems = total_bytes / 16; // always use dwordx4
1519
1520 using vector_type = ext_vector_t<index_t, 4>;
1521 vector_type* p_expert_mesh = reinterpret_cast<vector_type*>(kargs.p_expert_mesh);
1522 auto zero_ = vector_type{0};
1523
1524 for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elems;
1525 i += gridDim.x * kBlockSize)
1526 {
1527 p_expert_mesh[i] = zero_;
1528 }
1529 }
1530};
1531
1532// below kernel is multi-phase implementation for large token and/or expert case
1533
1534// write into a buffer to record the token cnt
1535// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
1536// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
1537// tok-0 tok-1 tok-2 tok-3 tok-4
1538// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
1539// number)
1540//
1541// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 4]]
1542// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
1543// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
1544/*
1545
1546p_expert_mesh:
1547 t0 t1 t2 t3 t4 r5
1548 +--+--+--+--+--+--+
1549e0 | 1| | | | | |
1550e1 | | | 1| 1| 1| |
1551e2 | | 1| | 1| | |
1552e3 | 1| 1| 1| 1| 1| |
1553e4 | | | | | | |
1554e5 | 1| 1| 1| | 1| |
1555
1556
1557p_expert_cumsum:
1558 | 1| 3| 2| 5| 0| 4|
1559 e0 e1 e2 e3 e4 e5
1560
1561p_expert_cumsum(with M_a pad, and skip zero tokens):
1562 | 4| 4| 4| 8| 0| 4|
1563 e0 e1 e2 e3 e4 e5
1564
1565p_expert_cumsum
1566 | 0| 4| 8|12|20|20|24|
1567
1568local_expert_mask : [1, 0, 1, 1, 0, 1] (mask out expert-id=1, 4)
1569
1570p_m_cumsum
1571 | 0| 1| 1| 2| 3| 3| 4|
1572
1573*/
1574
1575// count topk_id into mesh
1576template <typename Problem_>
1578{
1580
1581 using IndexType = typename Problem::IndexType;
1582 using WeightType = typename Problem::WeightType;
1583 using MeshType = typename Problem::MeshType;
1584
1585 static constexpr index_t kBlockSize = 256;
1586 static constexpr index_t OCCUPANCY = 2; // hard coded
1587
1589
1591
1592 struct Kargs
1593 {
1594 const void* p_topk_ids; // [tokens, topk]
1595 const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
1596 void* p_expert_mesh; // [expert, tokens]
1597 index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
1598 // used for ws/LDS calculation
1600 index_t mesh_stride; // mesh_stride for p_expert_mesh
1602 };
1603
1604 CK_TILE_HOST static constexpr auto get_num_cu()
1605 {
1606 index_t num_cu = [&]() {
1607 hipDeviceProp_t dev_prop;
1608 hipDevice_t dev;
1609 HIP_CHECK_ERROR(hipGetDevice(&dev));
1610 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
1611 return dev_prop.multiProcessorCount;
1612 }();
1613 return num_cu;
1614 }
1615
1616 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
1617 {
1618 Kargs k;
1619 k.p_topk_ids = h.p_topk_ids;
1621 k.p_expert_mesh = h.p_ws;
1622 k.tokens = h.tokens;
1625 k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
1626 return k;
1627 }
1628
1629 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
1630
1631 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
1632
1633 // in byte
1634 CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
1635
1637 {
1639
1640 const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
1641 MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
1642 index_t tokens = [&]() {
1643 if constexpr(Problem::LocalToken)
1644 {
1645 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
1646 }
1647 else
1648 {
1649 return kargs.tokens;
1650 }
1651 }();
1652 index_t rounded_tokens = [&]() {
1653 if constexpr(Problem::LocalToken)
1654 {
1655 return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
1656 Problem::SubTokenTile;
1657 }
1658 else
1659 return tokens;
1660 }();
1661 index_t mesh_stride = [&]() {
1662 if constexpr(Problem::LocalToken)
1663 {
1664 return impl::moe_sorting_mp_mesh_stride(tokens);
1665 }
1666 else
1667 {
1668 return kargs.mesh_stride;
1669 }
1670 }();
1671 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
1672
1673#pragma unroll Problem::SubTokenTile
1674 for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
1675 i += gridDim.x * kBlockSize)
1676 {
1677 auto x = p_topk_ids[i];
1679 IndexType eid = x[j.value]; // ext_vector_type must use int to []
1680 uint32_t curr_token_id, curr_topk_id;
1681 kargs.topk_mdiv.divmod(i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
1682 if(eid < kargs.num_experts)
1683 {
1684 if constexpr(Problem::LocalToken)
1685 {
1686 if(static_cast<index_t>(curr_token_id) < tokens)
1687 p_expert_mesh[eid * mesh_stride + curr_token_id] =
1688 (curr_topk_id + 1) & 0xffff;
1689 }
1690 else
1691 p_expert_mesh[eid * mesh_stride + curr_token_id] =
1692 (curr_topk_id + 1) & 0xffff;
1693 }
1694 });
1695 }
1696 }
1697};
1698template <typename Problem_>
1700{
1702
1703 using IndexType = typename Problem::IndexType;
1704 using WeightType = typename Problem::WeightType;
1705 using MeshType = typename Problem::MeshType;
1706
1707 static constexpr index_t kBlockSize = 512;
1708
1710
1712
1713 struct Kargs
1714 {
1715 const void* p_topk_ids; // [tokens, topk]
1716 const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
1717 void* p_expert_mesh; // [expert, tokens]
1718 index_t tokens; // if p_local_tokens is not nullptr, this indicate the max possible tokens
1719 // used for ws/LDS calculation
1720 index_t mesh_stride; // mesh_stride for p_expert_mesh
1722
1723 const void* p_local_expert_mask; // [expert]
1724 void* p_expert_cumsum; // [expert]
1726 };
1727
1728 CK_TILE_HOST static constexpr auto get_num_cu()
1729 {
1730 index_t num_cu = [&]() {
1731 hipDeviceProp_t dev_prop;
1732 hipDevice_t dev;
1733 HIP_CHECK_ERROR(hipGetDevice(&dev));
1734 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
1735 return dev_prop.multiProcessorCount;
1736 }();
1737 return num_cu;
1738 }
1739
1740 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
1741 {
1742 Kargs k;
1743 k.p_topk_ids = h.p_topk_ids;
1745 k.p_expert_mesh = h.p_ws;
1746 k.p_expert_cumsum = reinterpret_cast<void*>(
1747 reinterpret_cast<char*>(h.p_ws) +
1749 k.tokens = h.tokens;
1751 k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
1754 return k;
1755 }
1756
1757 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return h.num_experts; }
1758
1759 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
1760
1761 // in byte
1762 // CK_TILE_HOST static constexpr auto GetSmemSize() { return 0; }
1763 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
1764 {
1765 return kBlockSize / get_warp_size() * sizeof(IndexType);
1766 }
1767
1769 {
 // One workgroup per expert (eid = blockIdx.x). The block
 //   1) zeroes this expert's row of the mesh,
 //   2) scans all topk ids and, for every entry equal to eid, stores
 //      (topk index + 1) at the token's column of the row,
 //   3) counts the non-zero entries of the row and writes the count to
 //      p_expert_cumsum[eid].
1770 constexpr index_t index_pack = Problem::SubTokenTile; // always packed
1771 __shared__ char smem[GetSmemSize()];
1772 using topk_id_t = ext_vector_t<IndexType, index_pack>;
1773 const int eid = blockIdx.x;
1774 const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
1775 const IndexType* p_local_expert_mask =
1776 static_cast<const IndexType*>(kargs.p_local_expert_mask);
1777 IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
1778 index_t lane_id = threadIdx.x % get_warp_size();
1779 index_t wave_id = threadIdx.x / get_warp_size();
 // with LocalToken the real token count is read from device memory,
 // otherwise the host-provided value is used
1780 const index_t tokens = [&]() {
1781 if constexpr(Problem::LocalToken)
1782 {
1783 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
1784 }
1785 else
1786 {
1787 return kargs.tokens;
1788 }
1789 }();
 // token count rounded up to a multiple of index_pack (LocalToken only)
1790 index_t rounded_tokens = [&]() {
1791 if constexpr(Problem::LocalToken)
1792 {
1793 return (tokens + index_pack - 1) / index_pack * index_pack;
1794 }
1795 else
1796 return tokens;
1797 }();
 // the row stride must follow the runtime token count when LocalToken is set
1798 index_t mesh_stride = [&]() {
1799 if constexpr(Problem::LocalToken)
1800 {
1801 return impl::moe_sorting_mp_mesh_stride(tokens);
1802 }
1803 else
1804 {
1805 return kargs.mesh_stride;
1806 }
1807 }();
1808
 // mask stays 1 (expert enabled) unless LocalExpertMasking supplies a value
1809 IndexType mask = 1;
1810 if constexpr(Problem::LocalExpertMasking)
1811 {
1812 mask = p_local_expert_mask[eid];
1813 }
 // row of the mesh that belongs to this expert
1814 MeshType* p_expert_mesh =
1815 reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride;
 // step 1: clear the whole row
1816 for(index_t i = threadIdx.x; i < mesh_stride; i += kBlockSize)
1817 {
1818 p_expert_mesh[i] = 0;
1819 }
1821
 // number of index_pack-wide vectors of topk ids to scan
1822 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / index_pack;
1823
 // step 2: mark the tokens routed to this expert
1824#pragma unroll index_pack
1825 for(index_t i = threadIdx.x; i < total_elem; i += kBlockSize)
1826 {
1827 auto x = p_topk_ids[i];
1828 static_for<0, index_pack, 1>{}([&](auto j) {
1829 IndexType eid_x = x[j.value]; // ext_vector_type must use int to []
1830 if(eid_x == eid)
1831 {
 // flat index -> (token id, topk slot)
1832 uint32_t curr_token_id, curr_topk_id;
1833 kargs.topk_mdiv.divmod(i * index_pack + j, curr_token_id, curr_topk_id);
 // store topk slot + 1 so that 0 keeps meaning "token not routed here";
 // with LocalToken, ids beyond the real token count (rounding) are dropped
1834 if constexpr(Problem::LocalToken)
1835 {
1836 if(static_cast<index_t>(curr_token_id) < tokens)
1837 p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
1838 }
1839 else
1840 p_expert_mesh[curr_token_id] = (curr_topk_id + 1) & 0xffff;
1841 }
1842 });
1843 }
1845
 // step 3: count the non-zero entries of the row
1846 {
1847
1848 using r_t = ext_vector_t<MeshType, index_pack>; // index_pack mesh elements per load
1849 auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
1850 const r_t* p_expert_mesh_r = reinterpret_cast<r_t*>(p_expert_mesh);
1851
 // kBlockSize-wide passes needed to cover mesh_stride / index_pack vectors
1852 int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
1853
 // NOTE(review): `mask` is only loaded under Problem::LocalExpertMasking above,
 // yet this guard is keyed on Problem::LocalToken. As written, a masked-out
 // expert is skipped only when LocalToken is also enabled -- confirm whether
 // Problem::LocalExpertMasking was intended here.
1854 if(Problem::LocalToken && mask == 0)
1855 return; // skip
1856 index_t cnt = 0; // per-wave cnt
1857 for(int i = 0; i < loops; i++)
1858 {
1859 int position = i * kBlockSize + threadIdx.x;
 // out-of-range positions contribute an all-zero vector
1860 r_t v{0};
1861 if(position < (mesh_stride / index_pack))
1862 v = p_expert_mesh_r[position];
1863 index_t local_sum = 0;
1865 [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
 // sum the per-lane counts across the wave
1866 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
1867 }
1868
1869 // reduce cross wave
1870 IndexType* s = reinterpret_cast<IndexType*>(smem);
 // one slot per wave, written by lane 0 of each wave
1871 if(lane_id == 0)
1872 {
1873 s[wave_id] = cnt;
1874 }
1875 __syncthreads();
1876
 // thread 0 adds up the per-wave counts and publishes the expert total
1877 if(threadIdx.x == 0)
1878 {
1879 index_t c = 0;
1880 for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
1881 {
1882 c += s[i];
1883 }
1884 p_expert_cumsum[eid] = c;
1885 }
1886 }
1887 }
1888};
1889
1890// cnt total tokens for a expert
1891template <typename Problem_>
1893{
1895
1896 using IndexType = typename Problem::IndexType;
1897 using WeightType = typename Problem::WeightType;
1898 using MeshType = typename Problem::MeshType;
1899
1900 static constexpr index_t kBlockSize = 256;
1901 static constexpr index_t OCCUPANCY = 2; // hard coded
1902
1904
1906 struct Kargs
1907 {
1908 const void* p_local_expert_mask; // [expert]
1909 const void* p_local_tokens; // [1], if not nullptr, use this as actual tokens
1910 void* p_expert_mesh; // [expert, tokens]
1912 index_t mesh_stride; // mesh_stride for p_expert_mesh
1913 };
1914
1915 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
1916 {
1917 Kargs k;
1920 k.p_expert_mesh = h.p_ws;
1921 k.p_expert_cumsum = reinterpret_cast<void*>(
1922 reinterpret_cast<char*>(h.p_ws) +
1925
1926 return k;
1927 }
1928
1929 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
1930
1931 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
1932
1933 // in byte
1934 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
1935 {
1936 return kBlockSize / get_warp_size() * sizeof(IndexType);
1937 }
1938
1940 {
1941 __shared__ char smem[GetSmemSize()];
1942
1943 int eid = blockIdx.x;
1944 constexpr index_t index_pack = Problem::SubTokenTile; // always packed
1945 using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
1946
1947 const IndexType* p_local_expert_mask =
1948 static_cast<const IndexType*>(kargs.p_local_expert_mask);
1949 IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
1950
1951 auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
1952
1953 index_t tokens = [&]() {
1954 if constexpr(Problem::LocalToken)
1955 {
1956 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
1957 }
1958 else
1959 {
1960 return 0; // will not use if not LocalToken
1961 }
1962 }();
1963
1964 index_t mesh_stride = [&]() {
1965 if constexpr(Problem::LocalToken)
1966 {
1967 return impl::moe_sorting_mp_mesh_stride(tokens);
1968 }
1969 else
1970 {
1971 return kargs.mesh_stride;
1972 }
1973 }();
1974
1975 r_t* p_expert_mesh = reinterpret_cast<r_t*>(
1976 reinterpret_cast<MeshType*>(kargs.p_expert_mesh) + eid * mesh_stride);
1977
1978 int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
1979
1980 if constexpr(Problem::LocalExpertMasking)
1981 {
1982 IndexType mask = p_local_expert_mask[eid];
1983 if(mask == 0)
1984 return; // skip
1985 }
1986
1987 index_t cnt = 0; // per-wave cnt
1988 for(int i = 0; i < loops; i++)
1989 {
1990 int position = i * kBlockSize + threadIdx.x;
1991 r_t v{0};
1992 if(position < (mesh_stride / index_pack))
1993 v = p_expert_mesh[position];
1994 index_t local_sum = 0;
1996 [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
1997 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
1998 }
1999
2000 index_t lane_id = threadIdx.x % get_warp_size();
2001 index_t wave_id = threadIdx.x / get_warp_size();
2002
2003 // reduce cross wave
2004 IndexType* s = reinterpret_cast<IndexType*>(smem);
2005 if(lane_id == 0)
2006 {
2007 s[wave_id] = cnt;
2008 }
2009 __syncthreads();
2010
2011 if(threadIdx.x == 0)
2012 {
2013 index_t c = 0;
2014 for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
2015 {
2016 c += s[i];
2017 }
2018 p_expert_cumsum[eid] = c;
2019 }
2020 }
2021};
2022
2023#if MOE_SORTING_FUSE_MP_01
2024template <typename Problem_>
2025struct MoeSortingMultiPhaseKernel_P01
2026{
2027 using Problem = remove_cvref_t<Problem_>;
2028
2029 using IndexType = typename Problem::IndexType;
2030 using WeightType = typename Problem::WeightType;
2031 using MeshType = typename Problem::MeshType;
2032
2033 static constexpr index_t kBlockSize = 256;
2034 static constexpr index_t OCCUPANCY = 2; // hard coded
2035
2036 typedef MoeSortingHostArgs MoeSortingKargs;
2037
2038 using Hargs = MoeSortingHostArgs;
2039
2040 struct Kargs
2041 {
2042 const void* p_topk_ids; // [tokens, topk]
2043 const void* p_local_expert_mask; // [expert]
2044 const void* p_local_tokens; // [1]
2045 void* p_expert_mesh; // [expert, tokens]
2046 void* p_expert_cumsum; // [expert + 1]
2047 void* p_expert_sem; // [1]
2048 index_t tokens;
2049 index_t num_experts;
2050 index_t mesh_stride; // mesh_stride for p_expert_mesh
2051 index_t wg_count; // used for semaphore
2052 mdiv topk_mdiv;
2053 };
2054
2055 CK_TILE_HOST static constexpr auto get_num_cu()
2056 {
2057 index_t num_cu = [&]() {
2058 hipDeviceProp_t dev_prop;
2059 hipDevice_t dev;
2060 HIP_CHECK_ERROR(hipGetDevice(&dev));
2061 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
2062 return dev_prop.multiProcessorCount;
2063 }();
2064 return num_cu;
2065 }
2066
2067 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
2068 {
2069 Kargs k;
2070 k.p_topk_ids = h.p_topk_ids;
2071 k.p_local_expert_mask = h.p_local_expert_mask;
2072 k.p_local_tokens = h.p_local_tokens;
2073 k.p_expert_mesh = h.p_ws;
2074 k.p_expert_cumsum = reinterpret_cast<void*>(
2075 reinterpret_cast<char*>(h.p_ws) +
2076 impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk));
2077 k.p_expert_sem = reinterpret_cast<void*>(
2078 reinterpret_cast<char*>(h.p_ws) +
2079 impl::moe_sorting_mp_mesh_smem_size(h.tokens, h.num_experts, h.topk) +
2080 impl::moe_sorting_mp_cumsum_smem_size(h.num_experts));
2081 k.tokens = h.tokens;
2082 k.num_experts = h.num_experts;
2083 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
2084 k.wg_count = [&]() {
2085 if constexpr(Problem::LocalToken)
2086 {
2087 return GridSize(h);
2088 }
2089 else
2090 {
2091 return WGCounts(h);
2092 }
2093 }();
2094 k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
2095 return k;
2096 }
2097
2098 CK_TILE_HOST static constexpr auto GridSize(const Hargs&) { return get_num_cu() * OCCUPANCY; }
2099
2100 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
2101
2102 CK_TILE_HOST static constexpr auto WGCounts(const Hargs& h)
2103 {
2104 index_t total_elem = h.tokens * h.topk / Problem::SubTokenTile;
2105 index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
2106
2107 // no more than grid_size
2108 return min(elem_cnt, GridSize(h));
2109 }
2110
2111 // in byte
2112 CK_TILE_HOST static constexpr auto GetSmemSize()
2113 {
2114 return kBlockSize / get_warp_size() * sizeof(IndexType);
2115 }
2116
2117 CK_TILE_DEVICE void operator()(Kargs kargs) const
2118 {
2119 workgroup_barrier wb{reinterpret_cast<uint32_t*>(kargs.p_expert_sem)};
2120 index_t tokens = [&]() {
2121 if constexpr(Problem::LocalToken)
2122 {
2123 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
2124 }
2125 else
2126 {
2127 return kargs.tokens;
2128 }
2129 }();
2130 index_t rounded_tokens = [&]() {
2131 if constexpr(Problem::LocalToken)
2132 {
2133 return (tokens + Problem::SubTokenTile - 1) / Problem::SubTokenTile *
2134 Problem::SubTokenTile;
2135 }
2136 else
2137 return tokens;
2138 }();
2139 index_t wg_count = [&]() {
2140 if constexpr(Problem::LocalToken)
2141 {
2142 index_t total_elem = rounded_tokens * kargs.topk / Problem::SubTokenTile;
2143 index_t elem_cnt = (total_elem + kBlockSize - 1) / kBlockSize;
2144
2145 // no more than grid_size
2146 return min(elem_cnt, kargs.wg_count);
2147 }
2148 else
2149 {
2150 return kargs.wg_count;
2151 }
2152 }();
2153
2154 {
2155 using topk_id_t = ext_vector_t<IndexType, Problem::SubTokenTile>;
2156
2157 const topk_id_t* p_topk_ids = reinterpret_cast<const topk_id_t*>(kargs.p_topk_ids);
2158 IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
2159 index_t total_elem = rounded_tokens * kargs.topk_mdiv.divisor / Problem::SubTokenTile;
2160
2161#pragma unroll Problem::SubTokenTile
2162 for(index_t i = blockIdx.x * kBlockSize + threadIdx.x; i < total_elem;
2163 i += kBlockSize * gridDim.x)
2164 {
2165 auto x = p_topk_ids[i];
2166 static_for<0, Problem::SubTokenTile, 1>{}([&](auto j) {
2167 IndexType eid = x[j.value]; // ext_vector_type must use int to []
2168 uint32_t curr_token_id, curr_topk_id;
2169 kargs.topk_mdiv.divmod(
2170 i * Problem::SubTokenTile + j, curr_token_id, curr_topk_id);
2171 // p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] = curr_topk_id + 1;
2172 if constexpr(Problem::LocalToken)
2173 {
2174 if(static_cast<index_t>(curr_token_id) < tokens)
2175 p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
2176 (curr_topk_id + 1) & 0xffff;
2177 }
2178 else
2179 p_expert_mesh[eid * kargs.mesh_stride + curr_token_id] =
2180 (curr_topk_id + 1) & 0xffff;
2181 });
2182 }
2183 if(static_cast<index_t>(blockIdx.x) < wg_count)
2184 {
2185 wb.inc();
2186 }
2187 }
2188
2189 {
2190 __shared__ char smem[GetSmemSize()];
2191 int eid = blockIdx.x;
2192
2193 // early exist in case of extra atomic wait
2194 if(eid >= kargs.num_experts)
2195 return;
2196
2197 wb.wait_lt(wg_count);
2198
2199 for(; eid < kargs.num_experts; eid += gridDim.x)
2200 {
2201 // if(threadIdx.x == 0)
2202 // printf("!!! bid:%d, eid:%d (%d, %d)\n",
2203 // static_cast<int>(blockIdx.x),
2204 // eid,
2205 // kargs.num_experts,
2206 // static_cast<int>(blockDim.x));
2207 constexpr index_t index_pack = 4; // always packed
2208 using r_t = ext_vector_t<IndexType, index_pack>; // always use int32x4
2209 r_t* p_expert_mesh = reinterpret_cast<r_t*>(
2210 reinterpret_cast<index_t*>(kargs.p_expert_mesh) + eid * kargs.mesh_stride);
2211
2212 const IndexType* p_local_expert_mask =
2213 static_cast<const IndexType*>(kargs.p_local_expert_mask);
2214 IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
2215
2216 auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
2217
2218 int loops = (kargs.mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
2219
2220 if constexpr(Problem::LocalExpertMasking)
2221 {
2222 IndexType mask = p_local_expert_mask[eid];
2223 if(mask == 0)
2224 continue; // skip
2225 }
2226
2227 index_t cnt = 0; // per-wave cnt
2228 for(int i = 0; i < loops; i++)
2229 {
2230 int position = i * kBlockSize + threadIdx.x;
2231 r_t v{0};
2232 if(position < (kargs.mesh_stride / index_pack))
2233 v = p_expert_mesh[position];
2234 index_t local_sum = 0;
2235 static_for<0, index_pack, 1>{}(
2236 [&](auto i_vec) { local_sum += v[i_vec.value] != 0 ? 1 : 0; });
2237 cnt += impl::moe_sorting_wave_reduce(local_sum, f_sum);
2238 }
2239
2240 index_t lane_id = threadIdx.x % get_warp_size();
2241 index_t wave_id = threadIdx.x / get_warp_size();
2242
2243 // reduce cross wave
2244 IndexType* s = reinterpret_cast<IndexType*>(smem);
2245 __syncthreads();
2246 if(lane_id == 0)
2247 {
2248 s[wave_id] = cnt;
2249 }
2250 __syncthreads();
2251
2252 if(threadIdx.x == 0)
2253 {
2254 index_t c = 0;
2255 for(auto i = 0; i < (kBlockSize / get_warp_size()); i++)
2256 {
2257 c += s[i];
2258 }
2259 p_expert_cumsum[eid] = c;
2260 }
2261 }
2262 }
2263 }
2264};
2265#endif
2266
2267// token count cumsum
2268template <typename Problem_>
2270{
2272
2273 using IndexType = typename Problem::IndexType;
2274 using WeightType = typename Problem::WeightType;
2275 using MeshType = typename Problem::MeshType;
2276
2277 static constexpr index_t kBlockSize = 256;
2278 static constexpr index_t OCCUPANCY = 2; // hard coded
2279
2281
2283 struct Kargs
2284 {
2285 const void* p_local_expert_mask; // [expert]
2286 const void* p_local_tokens; // [1]
2287 void* p_expert_mesh; // [expert, tokens]
2288 void* p_expert_cumsum; // [expert + 1]
2294 index_t mesh_stride; // mesh_stride for p_expert_mesh
2297 };
2298
2299 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
2300 {
2301 Kargs k;
2302 k.p_local_expert_mask = h.p_local_expert_mask;
2303 k.p_local_tokens = h.p_local_tokens;
2304 k.p_expert_cumsum = reinterpret_cast<void*>(
2305 reinterpret_cast<char*>(h.p_ws) +
2307 k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
2308 k.p_sorted_expert_ids = h.p_sorted_expert_ids;
2309
2310 k.p_moe_buf = h.p_moe_buf;
2311
2312 k.tokens = h.tokens;
2313 k.num_experts = h.num_experts;
2314 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
2315 k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
2316
2317#if MOE_SORTING_FMOE_2D_BUF
2318 k.moe_buf_interm_dim = h.moe_buf_interm_dim;
2319 k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
2320#else
2321 k.moe_buf_bytes = h.moe_buf_bytes;
2322#endif
2323
2324 return k;
2325 }
2326
2327 CK_TILE_HOST static constexpr auto get_num_cu()
2328 {
2329 index_t num_cu = [&]() {
2330 hipDeviceProp_t dev_prop;
2331 hipDevice_t dev;
2332 HIP_CHECK_ERROR(hipGetDevice(&dev));
2333 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
2334 return dev_prop.multiProcessorCount;
2335 }();
2336 return num_cu;
2337 }
2338
2339 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
2340 {
2341#if MOE_SORTING_FMOE_2D_BUF
2342 return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
2343#else
2344 // use 1 block to cumsum
2345 return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
2346#endif
2347 }
2348
2349 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
2350
2351 // in byte
2352 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
2353 {
2354 // return 2 * kBlockSize * sizeof(IndexType);
2355 return (4 + 2 * kBlockSize / get_warp_size()) * sizeof(IndexType);
2356 }
2357
2358 // reduce single pixel within a wave
2360 {
2361 if(blockIdx.x > 0)
2362 {
2363#if MOE_SORTING_FMOE_2D_BUF
2365 kargs.tokens,
2366 kargs.moe_buf_interm_dim,
2367 kargs.moe_buf_elem_bytes,
2368 blockIdx.x - 1,
2369 gridDim.x - 1);
2370 return;
2371#else
2373 reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
2374 kargs.moe_buf_bytes,
2375 blockIdx.x - 1);
2376 return;
2377#endif
2378 }
2379 __shared__ char smem[GetSmemSize()];
2380 IndexType* s = reinterpret_cast<IndexType*>(smem);
2381
2382 const IndexType* p_local_expert_mask =
2383 static_cast<const IndexType*>(kargs.p_local_expert_mask);
2384 IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
2385 IndexType* p_total_tokens_post_pad =
2386 reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
2387 IndexType* p_sorted_expert_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
2388
2389 const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
2390 index_t wave_id = threadIdx.x / get_warp_size();
2391 index_t lane_id = threadIdx.x % get_warp_size();
2392
2393 IndexType prev_cumsum_a = 0;
2394 IndexType prev_cumsum_b = 0;
2395
2396 for(index_t i = 0; i < loops; i++)
2397 {
2398 index_t position = i * kBlockSize + threadIdx.x;
2399 IndexType a_ = 0; // token count for a expert
2400 IndexType b_ = 0; // mask for a expert
2401 if(position < kargs.num_experts)
2402 {
2403 a_ = p_expert_cumsum[position];
2404 if constexpr(Problem::LocalExpertMasking)
2405 b_ = p_local_expert_mask[position];
2406 }
2407
2408 int blocks_pers_expert =
2409 kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1);
2410 // pad token
2411 int padded_blocks_per_expert = [&]() {
2412 int x_ = [&]() {
2413 if constexpr(Problem::SkipExpertsWithZeroTokens)
2414 {
2415 // if local_cnt is zero, blocks_pers_expert will be zero
2416 // this is what we want to achieve
2417 return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor;
2418 }
2419 else
2420 {
2421 return max(blocks_pers_expert, 1);
2422 }
2423 }();
2424 if constexpr(Problem::LocalExpertMasking)
2425 {
2426 return b_ ? x_ : 0;
2427 }
2428 else
2429 return x_;
2430 }();
2431
2432 IndexType cumsum_a = padded_blocks_per_expert;
2433 IndexType cumsum_b = b_;
2434
2435 // Note: we first cumsum local round, then add previous cumsum
2438
2439 __syncthreads();
2440 if(lane_id == get_warp_size() - 1)
2441 {
2442 s[4 + wave_id] = cumsum_a;
2443 s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
2444 }
2445
2446 __syncthreads();
2447
2448 // reduce cross wave
2449 static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
2450 IndexType prev_a = s[4 + i_w];
2451 IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
2452 prev_a = wave_id > i_w ? prev_a : 0; // mask out
2453 prev_b = wave_id > i_w ? prev_b : 0; // mask out
2454 cumsum_a += prev_a;
2455 cumsum_b += prev_b;
2456 });
2457
2458 // Now let's add previous cumsum
2459 cumsum_a += prev_cumsum_a;
2460 cumsum_b += prev_cumsum_b;
2461
2462 if(threadIdx.x == kBlockSize - 1)
2463 {
2464 s[2] = cumsum_a; // store the last cumsum
2465 s[3] = cumsum_b;
2466 }
2467
2468 IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt
2469 IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt
2470
2471 __syncthreads();
2472 prev_cumsum_a = s[2];
2473 prev_cumsum_b = s[3];
2474
2475 if(position < kargs.num_experts)
2476 {
2477 p_expert_cumsum[position] = out_0 * kargs.unit_size_mdiv.divisor;
2478 }
2479
2480 {
2481 if constexpr(Problem::LocalExpertMasking)
2482 {
2483 if(b_)
2484 {
2485 for(int j = 0; j < blocks_pers_expert; j++)
2486 {
2487 p_sorted_expert_ids[out_0 + j] = out_1;
2488 }
2489 }
2490 }
2491 else
2492 {
2493 for(int j = 0; j < blocks_pers_expert; j++)
2494 {
2495 p_sorted_expert_ids[out_0 + j] = position;
2496 }
2497 }
2498 }
2499 }
2500
2501 if(threadIdx.x == 0)
2502 {
2503 auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
2504 p_total_tokens_post_pad[0] = total_tokens_post_pad;
2505 p_expert_cumsum[kargs.num_experts] = total_tokens_post_pad;
2506 }
2507 }
2508};
2509
2510template <typename Problem_>
2512{
2514
2515 using IndexType = typename Problem::IndexType;
2516 using WeightType = typename Problem::WeightType;
2517 using MeshType = typename Problem::MeshType;
2518
2519 static constexpr index_t kBlockSize = 256;
2520 static constexpr index_t OCCUPANCY = 2; // hard coded
2521
2523
2525
2526 struct Kargs
2527 {
2528 const void* p_weights;
2530 const void* p_local_tokens;
2533 void* p_expert_mesh; // [token, expert]
2535
2538 index_t mesh_stride; // mesh_stride for p_expert_mesh
2540 };
2541
2542 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
2543 {
2544 Kargs k;
2545 k.p_weights = h.p_weights;
2546 k.p_local_expert_mask = h.p_local_expert_mask;
2547 k.p_local_tokens = h.p_local_tokens;
2548 k.p_sorted_token_ids = h.p_sorted_token_ids;
2549 k.p_sorted_weights = h.p_sorted_weights;
2550 k.p_expert_mesh = h.p_ws;
2551 k.p_expert_cumsum = reinterpret_cast<void*>(
2552 reinterpret_cast<char*>(h.p_ws) +
2554 k.tokens = h.tokens;
2555 k.num_experts = h.num_experts;
2556 k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
2557 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
2558 return k;
2559 }
2560
2561 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h) { return dim3(h.num_experts); }
2562
2563 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
2564
2565 // in byte
2566 CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize()
2567 {
2568 return (4 + kBlockSize / get_warp_size()) * sizeof(IndexType);
2569 }
2570
2572 {
2573 __shared__ char smem[GetSmemSize()];
2574
2575 const IndexType* p_local_expert_mask =
2576 static_cast<const IndexType*>(kargs.p_local_expert_mask);
2577 IndexType* s = reinterpret_cast<IndexType*>(smem);
2578 IndexType* p_expert_mesh = reinterpret_cast<IndexType*>(kargs.p_expert_mesh);
2579 IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
2580 IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
2581 const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
2582 WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
2583
2584 index_t tokens = [&]() {
2585 if constexpr(Problem::LocalToken)
2586 {
2587 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
2588 }
2589 else
2590 {
2591 return kargs.tokens;
2592 }
2593 }();
2594 int eid = blockIdx.x;
2595 int wave_id = threadIdx.x / get_warp_size();
2596 int lane_id = threadIdx.x % get_warp_size();
2597 int e_start = p_expert_cumsum[eid];
2598 int e_end = p_expert_cumsum[eid + 1];
2599 if constexpr(Problem::SkipExpertsWithZeroTokens)
2600 {
2601 if(e_start == e_end)
2602 return;
2603 }
2604
2605 if constexpr(Problem::LocalExpertMasking)
2606 {
2607 int e_mask = p_local_expert_mask[eid];
2608 if(e_mask == 0)
2609 return; // skip empty expert
2610 }
2611
2612 // cumsum one by one
2613 int loops = (kargs.mesh_stride + kBlockSize - 1) / kBlockSize;
2614 int prev_cumsum = 0;
2615 for(int i = 0; i < loops; i++)
2616 {
2617 int i_token = i * kBlockSize + threadIdx.x;
2618 IndexType x = 0;
2619 if(i_token < tokens)
2620 {
2621 x = p_expert_mesh[eid * kargs.mesh_stride + i_token];
2622 }
2623 int i_topk = x - 1; // topk of this token
2624 int i_show = x != 0 ? 1 : 0; // has this token or not
2625 int cumsum = i_show;
2627
2628 __syncthreads();
2629 if(lane_id == get_warp_size() - 1)
2630 {
2631 s[4 + wave_id] = cumsum;
2632 }
2633 __syncthreads();
2634
2635 // reduce cross wave
2636 static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
2637 IndexType prev = s[4 + i_w];
2638 prev = wave_id > i_w ? prev : 0; // mask out
2639 cumsum += prev;
2640 });
2641 cumsum += prev_cumsum; // add previous round cumsum
2642 if(threadIdx.x == kBlockSize - 1)
2643 {
2644 s[0] = cumsum;
2645 }
2646 __syncthreads();
2647
2648 int position = cumsum - i_show;
2649 prev_cumsum = s[0]; // update the last cumsum
2650
2651 if(i_show)
2652 {
2653#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
2654 p_sorted_token_ids[e_start + position] = MOE_SORTING_MOCK_ID(i_token, i_topk);
2655#else
2656 p_sorted_token_ids[e_start + position] = i_token;
2657#endif
2658 p_sorted_weights[e_start + position] =
2659 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
2660 }
2661 }
2662
2663 for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
2664 {
2665#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
2666 p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
2667#else
2668 p_sorted_token_ids[i] = tokens;
2669#endif
2670 p_sorted_weights[i] = static_cast<WeightType>(0.0);
2671 }
2672 }
2673};
2674
2675namespace impl {
2676// we use dynamic LDS size here
2677CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
2678{
2679 constexpr index_t kBlockSize = 256; // hardcoded 256
2680 const index_t expert_cumsum_elem = num_experts_ + 1;
2681 return (4 + 2 * kBlockSize / get_warp_size() + expert_cumsum_elem) * sizeof(int);
2682}
2683} // namespace impl
2684
2685// token count cumsum
2686template <typename Problem_>
2688{
2690
2691 using IndexType = typename Problem::IndexType;
2692 using WeightType = typename Problem::WeightType;
2693 using MeshType = typename Problem::MeshType;
2694
2695 static constexpr index_t kBlockSize = 256;
2696 static constexpr index_t OCCUPANCY = 2; // hard coded
2697
2699
2701 struct Kargs
2702 {
2703 const void* p_weights;
2704 const void* p_local_expert_mask; // [expert]
2705 const void* p_local_tokens; // [1]
2706 void* p_expert_mesh; // [expert, tokens]
2707 void* p_expert_cumsum; // [expert + 1]
2710
2714
2717 index_t mesh_stride; // mesh_stride for p_expert_mesh
2720#if MOE_SORTING_FMOE_2D_BUF
2721 // NOTE:
2722 // moe_buf_* is a 2d ws buffer used for the following fmoe kernel
2723 // arranged as row*col, where row=tokens(or local_token), col=interm_dim
2724 // we fuse this clearing inside sorting kernel
2725 // Besides, we require inter_dim to be multiple of 16 byte(make sure when alloc ws for fmoe)
2726 index_t moe_buf_interm_dim; // p_moe_buf interm_dim
2727 index_t moe_buf_elem_bytes; // p_moe_buf byte size(8bit, 16bit, 32bit, etc.)
2728#else
2729 long_index_t moe_buf_bytes; // byte size of p_moe_buf
2730#endif
2731 };
2732
2733 CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
2734 {
2735 Kargs k;
2736 k.p_weights = h.p_weights;
2737 k.p_local_expert_mask = h.p_local_expert_mask;
2738 k.p_local_tokens = h.p_local_tokens;
2739 k.p_expert_mesh = h.p_ws;
2740 k.p_expert_cumsum = reinterpret_cast<void*>(
2741 reinterpret_cast<char*>(h.p_ws) +
2743 k.p_total_tokens_post_pad = h.p_total_tokens_post_pad;
2744 k.p_sorted_expert_ids = h.p_sorted_expert_ids;
2745
2746 k.p_sorted_token_ids = h.p_sorted_token_ids;
2747 k.p_sorted_weights = h.p_sorted_weights;
2748
2749 k.p_moe_buf = h.p_moe_buf;
2750
2751 k.tokens = h.tokens;
2752 k.num_experts = h.num_experts;
2753 k.mesh_stride = impl::moe_sorting_mp_mesh_stride(h.tokens);
2754 k.unit_size_mdiv = mdiv{static_cast<uint32_t>(h.unit_size)};
2755 k.topk_mdiv = mdiv{static_cast<uint32_t>(h.topk)};
2756
2757#if MOE_SORTING_FMOE_2D_BUF
2758 k.moe_buf_interm_dim = h.moe_buf_interm_dim;
2759 k.moe_buf_elem_bytes = h.moe_buf_elem_bytes;
2760#else
2761 k.moe_buf_bytes = h.moe_buf_bytes;
2762#endif
2763
2764 return k;
2765 }
2766
2767 CK_TILE_HOST static constexpr auto get_num_cu()
2768 {
2769 index_t num_cu = [&]() {
2770 hipDeviceProp_t dev_prop;
2771 hipDevice_t dev;
2772 HIP_CHECK_ERROR(hipGetDevice(&dev));
2773 HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
2774 return dev_prop.multiProcessorCount;
2775 }();
2776 return num_cu;
2777 }
2778
2779 CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
2780 {
2781#if MOE_SORTING_FMOE_2D_BUF
2782 return dim3(h.num_experts + get_num_cu() * OCCUPANCY);
2783#else
2784 // use 1 block to cumsum
2785 // return dim3(1 + ck_tile::integer_divide_ceil(h.moe_buf_bytes, kBlockSize * 16));
2787#endif
2788 }
2789
2790 CK_TILE_HOST static constexpr auto BlockSize(const Hargs&) { return dim3(kBlockSize); }
2791
2792 // only use this at host !
2793 CK_TILE_HOST static constexpr auto GetSmemSize(const Hargs& h)
2794 {
2795 const auto smem_23 = impl::moe_sorting_get_smem_size_p23(h.num_experts);
2796 const auto smem_sf = kBlockSize * 4 * sizeof(IndexType);
2797 return max(smem_23, smem_sf);
2798 }
2799
2800 // reduce single pixel within a wave
2802 {
2803 index_t tokens = [&]() {
2804 if constexpr(Problem::LocalToken)
2805 {
2806 return reinterpret_cast<const index_t*>(kargs.p_local_tokens)[0];
2807 }
2808 else
2809 {
2810 return kargs.tokens;
2811 }
2812 }();
2813
2814 if(static_cast<index_t>(blockIdx.x) >= kargs.num_experts)
2815 {
2816#if MOE_SORTING_FMOE_2D_BUF
2818 tokens,
2819 kargs.moe_buf_interm_dim,
2820 kargs.moe_buf_elem_bytes,
2821 blockIdx.x - kargs.num_experts,
2822 gridDim.x - kargs.num_experts);
2823 return;
2824#else
2826 reinterpret_cast<uint8x16_t*>(kargs.p_moe_buf),
2827 kargs.moe_buf_bytes,
2828 blockIdx.x - kargs.num_experts);
2829 return;
2830#endif
2831 }
2832
2833 extern __shared__ char smem[];
2834 {
2835 IndexType* s = reinterpret_cast<IndexType*>(smem);
2836
2837 const IndexType* p_local_expert_mask =
2838 static_cast<const IndexType*>(kargs.p_local_expert_mask);
2839 IndexType* p_expert_cumsum = reinterpret_cast<IndexType*>(kargs.p_expert_cumsum);
2840 IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
2841 IndexType* p_total_tokens_post_pad =
2842 reinterpret_cast<IndexType*>(kargs.p_total_tokens_post_pad);
2843 IndexType* p_sorted_expert_ids =
2844 reinterpret_cast<IndexType*>(kargs.p_sorted_expert_ids);
2845
2846 const index_t loops = (kargs.num_experts + kBlockSize - 1) / kBlockSize;
2847 index_t wave_id = threadIdx.x / get_warp_size();
2848 index_t lane_id = threadIdx.x % get_warp_size();
2849
2850 IndexType prev_cumsum_a = 0;
2851 IndexType prev_cumsum_b = 0;
2852
2853 for(index_t i = 0; i < loops; i++)
2854 {
2855 index_t position = i * kBlockSize + threadIdx.x;
2856 IndexType a_ = 0; // token count for a expert
2857 IndexType b_ = 0; // mask for a expert
2858 if(position < kargs.num_experts)
2859 {
2860 a_ = p_expert_cumsum[position];
2861 if constexpr(Problem::LocalExpertMasking)
2862 b_ = p_local_expert_mask[position];
2863 }
2864
2865 int blocks_pers_expert =
2866 kargs.unit_size_mdiv.div(a_ + kargs.unit_size_mdiv.divisor - 1);
2867 // pad token
2868 int padded_blocks_per_expert = [&]() {
2869 int x_ = [&]() {
2870 if constexpr(Problem::SkipExpertsWithZeroTokens)
2871 {
2872 // if local_cnt is zero, blocks_pers_expert will be zero
2873 // this is what we want to achieve
2874 return blocks_pers_expert; // * kargs.unit_size_mdiv.divisor;
2875 }
2876 else
2877 {
2878 return max(blocks_pers_expert, 1);
2879 }
2880 }();
2881 if constexpr(Problem::LocalExpertMasking)
2882 {
2883 return b_ ? x_ : 0;
2884 }
2885 else
2886 return x_;
2887 }();
2888
2889 IndexType cumsum_a = padded_blocks_per_expert;
2890 IndexType cumsum_b = b_;
2891
2892 // Note: we first cumsum local round, then add previous cumsum
2895
2896 __syncthreads();
2897 if(lane_id == get_warp_size() - 1)
2898 {
2899 s[4 + wave_id] = cumsum_a;
2900 s[4 + wave_id + kBlockSize / get_warp_size()] = cumsum_b;
2901 }
2902
2903 __syncthreads();
2904
2905 // reduce cross wave
2906 static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
2907 IndexType prev_a = s[4 + i_w];
2908 IndexType prev_b = s[4 + i_w + kBlockSize / get_warp_size()];
2909 prev_a = wave_id > i_w ? prev_a : 0; // mask out
2910 prev_b = wave_id > i_w ? prev_b : 0; // mask out
2911 cumsum_a += prev_a;
2912 cumsum_b += prev_b;
2913 });
2914
2915 // Now let's add previous cumsum
2916 cumsum_a += prev_cumsum_a;
2917 cumsum_b += prev_cumsum_b;
2918
2919 if(threadIdx.x == kBlockSize - 1)
2920 {
2921 s[2] = cumsum_a; // store the last cumsum
2922 s[3] = cumsum_b;
2923 }
2924
2925 IndexType out_0 = cumsum_a - padded_blocks_per_expert; // exclusive cumsum tok cnt
2926 IndexType out_1 = cumsum_b - b_; // exclusive cumsum mask cnt
2927
2928 __syncthreads();
2929 prev_cumsum_a = s[2];
2930 prev_cumsum_b = s[3];
2931
2932 if(position < kargs.num_experts)
2933 {
2934 p_expert_cumsum_smem[position] = out_0 * kargs.unit_size_mdiv.divisor;
2935 }
2936
2937 {
2938 if(blockIdx.x == 0)
2939 {
2940 if constexpr(Problem::LocalExpertMasking)
2941 {
2942 if(b_)
2943 {
2944 for(int j = 0; j < blocks_pers_expert; j++)
2945 {
2946 p_sorted_expert_ids[out_0 + j] = out_1;
2947 }
2948 }
2949 }
2950 else
2951 {
2952 for(int j = 0; j < blocks_pers_expert; j++)
2953 {
2954 p_sorted_expert_ids[out_0 + j] = position;
2955 }
2956 }
2957 }
2958 }
2959 }
2960
2961 if(threadIdx.x == 0)
2962 {
2963 auto total_tokens_post_pad = prev_cumsum_a * kargs.unit_size_mdiv.divisor;
2964 if(blockIdx.x == 0)
2965 {
2966 p_total_tokens_post_pad[0] = total_tokens_post_pad;
2967 p_total_tokens_post_pad[1] = tokens;
2968 }
2969 p_expert_cumsum_smem[kargs.num_experts] = total_tokens_post_pad;
2970 }
2971 }
2972
2973 __syncthreads();
2974 {
2975 const IndexType* p_local_expert_mask =
2976 static_cast<const IndexType*>(kargs.p_local_expert_mask);
2977 IndexType* s = reinterpret_cast<IndexType*>(smem);
2978 MeshType* p_expert_mesh = reinterpret_cast<MeshType*>(kargs.p_expert_mesh);
2979 IndexType* p_sorted_token_ids = reinterpret_cast<IndexType*>(kargs.p_sorted_token_ids);
2980 IndexType* p_expert_cumsum_smem = s + 4 + 2 * kBlockSize / get_warp_size();
2981 const WeightType* p_weights = static_cast<const WeightType*>(kargs.p_weights);
2982 WeightType* p_sorted_weights = reinterpret_cast<WeightType*>(kargs.p_sorted_weights);
2983
2984 int eid = blockIdx.x;
2985 int wave_id = threadIdx.x / get_warp_size();
2986 int lane_id = threadIdx.x % get_warp_size();
2987 int e_start = p_expert_cumsum_smem[eid];
2988 int e_end = p_expert_cumsum_smem[eid + 1];
2989 if constexpr(Problem::SkipExpertsWithZeroTokens)
2990 {
2991 if(e_start == e_end)
2992 return;
2993 }
2994
2995 if constexpr(Problem::LocalExpertMasking)
2996 {
2997 int e_mask = p_local_expert_mask[eid];
2998 if(e_mask == 0)
2999 return; // skip empty expert
3000 }
3001
3002 index_t mesh_stride = [&]() {
3003 if constexpr(Problem::LocalToken)
3004 {
3005 return impl::moe_sorting_mp_mesh_stride(tokens);
3006 }
3007 else
3008 {
3009 return kargs.mesh_stride;
3010 }
3011 }();
3012
3013 // cumsum one by one
3014 constexpr index_t index_pack = Problem::SubTokenTile; // always packed
3015 using r_t = ext_vector_t<MeshType, index_pack>; // always use int32x4
3017 int loops = (mesh_stride / index_pack + kBlockSize - 1) / kBlockSize;
3018
3019 int prev_cumsum = 0;
3020
3021 for(int i = 0; i < loops; i++)
3022 {
3023 int i_token_pack = i * kBlockSize + threadIdx.x;
3024 r_t x_v = 0;
3025 if(i_token_pack < (tokens + index_pack - 1) / index_pack)
3026 {
3027 x_v = reinterpret_cast<r_t*>(p_expert_mesh + eid * mesh_stride)[i_token_pack];
3028 }
3029
3030 r_t x_r;
3031#if 0
3032 if constexpr(index_pack != 1)
3033 {
3034 // shuffle, we must have contiguout thread holds contiguout token
3035 __syncthreads();
3036 reinterpret_cast<r_t*>(s)[threadIdx.x] = x_v;
3037 __syncthreads();
3038
3039 static_for<0, index_pack, 1>{}([&](auto j_) {
3040 constexpr auto j = j_.value;
3041 x_r[j] = reinterpret_cast<MeshType*>(s)[threadIdx.x + j * kBlockSize];
3042 });
3043 }
3044#else
3045 x_r = x_v;
3046#endif
3047 {
3048#if 0
3049#pragma unroll
3050 for(int j = 0; j < index_pack / 2; j++)
3051 {
3052 int i_token = i * kBlockSize * index_pack + threadIdx.x + j * kBlockSize;
3053 index_t x = x_d[j];
3054 int i_topk = x - 1; // topk of this token
3055 int i_show = x != 0 ? 1 : 0; // has this token or not
3056 int cumsum = i_show;
3058
3059 __syncthreads();
3060 if(lane_id == get_warp_size() - 1)
3061 {
3062 s[4 + wave_id] = cumsum;
3063 }
3064 __syncthreads();
3065
3066 // reduce cross wave
3067 static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
3068 IndexType prev = s[4 + i_w];
3069 prev = wave_id > i_w ? prev : 0; // mask out
3070 cumsum += prev;
3071 });
3072 cumsum += prev_cumsum; // add previous round cumsum
3073 if(threadIdx.x == kBlockSize - 1)
3074 {
3075 s[0] = cumsum;
3076 }
3077 __syncthreads();
3078
3079 int position = cumsum - i_show;
3080 prev_cumsum = s[0]; // update the last cumsum
3081
3082 if(i_show)
3083 {
3084#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3085 p_sorted_token_ids[e_start + position] =
3086 MOE_SORTING_MOCK_ID(i_token, i_topk);
3087#else
3088 p_sorted_token_ids[e_start + position] = i_token;
3089#endif
3090 p_sorted_weights[e_start + position] =
3091 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk];
3092 }
3093 }
3094#endif
3095 {
3096 d_t i_topk;
3097 d_t i_show;
3098 // = 0;
3099 int cumsum_store = 0;
3100
3101 static_for<0, index_pack, 1>{}([&](auto j_) {
3102 constexpr auto j = j_.value;
3103 i_topk[j] = static_cast<index_t>(x_r[j] - 1);
3104 i_show[j] = static_cast<index_t>(x_r[j] != 0 ? 1 : 0);
3105 cumsum_store += i_show[j];
3106 });
3107 int cumsum = cumsum_store;
3109
3110 __syncthreads();
3111 if(lane_id == get_warp_size() - 1)
3112 {
3113 s[4 + wave_id] = cumsum;
3114 }
3115 __syncthreads();
3116
3117 // reduce cross wave
3118 static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
3119 IndexType prev = s[4 + i_w];
3120 prev = wave_id > i_w ? prev : 0; // mask out
3121 cumsum += prev;
3122 });
3123 cumsum += prev_cumsum; // add previous round cumsum
3124 if(threadIdx.x == kBlockSize - 1)
3125 {
3126 s[0] = cumsum;
3127 }
3128 __syncthreads();
3129 prev_cumsum = s[0]; // update the last cumsum
3130
3131 int position = cumsum - cumsum_store;
3132 static_for<0, index_pack, 1>{}([&](auto j_) {
3133 constexpr auto j = j_.value;
3134 // int i_token = i * kBlockSize * index_pack + threadIdx.x + j *
3135 // kBlockSize;
3136 int i_token =
3137 i * kBlockSize * index_pack + threadIdx.x * index_pack + j;
3138
3139 if(i_show[j])
3140 {
3141#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3142 p_sorted_token_ids[e_start + position] =
3143 MOE_SORTING_MOCK_ID(i_token, i_topk[j]);
3144#else
3145 p_sorted_token_ids[e_start + position] = i_token;
3146#endif
3147 p_sorted_weights[e_start + position] =
3148 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk[j]];
3149 }
3150 position += i_show[j];
3151 });
3152
3153#if 0
3154 int i_token = i * kBlockSize * index_pack + threadIdx.x * 2 + j * kBlockSize * 2;
3155 index_t x = x_d[j];
3156 index_t x0 = static_cast<index_t>(x & 0xffff);
3157 index_t x1 = static_cast<index_t>(x >> 16);
3158 int i_topk_0 = x0 - 1; // topk of this token
3159 int i_show_0 = x0 != 0 ? 1 : 0; // has this token or not
3160 int i_topk_1 = x1 - 1; // topk of this token
3161 int i_show_1 = x1 != 0 ? 1 : 0; // has this token or not
3162 int cumsum = i_show_0 + i_show_1;
3164
3165 __syncthreads();
3166 if(lane_id == get_warp_size() - 1)
3167 {
3168 s[4 + wave_id] = cumsum;
3169 }
3170 __syncthreads();
3171
3172 // reduce cross wave
3173 static_for<0, kBlockSize / get_warp_size() - 1, 1>{}([&](auto i_w) {
3174 IndexType prev = s[4 + i_w];
3175 prev = wave_id > i_w ? prev : 0; // mask out
3176 cumsum += prev;
3177 });
3178 cumsum += prev_cumsum; // add previous round cumsum
3179 if(threadIdx.x == kBlockSize - 1)
3180 {
3181 s[0] = cumsum;
3182 }
3183 __syncthreads();
3184
3185 int position_0 = cumsum - i_show_0 - i_show_1;
3186 prev_cumsum = s[0]; // update the last cumsum
3187
3188 if(i_show_0)
3189 {
3190#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3191 p_sorted_token_ids[e_start + position_0] =
3192 MOE_SORTING_MOCK_ID(i_token, i_topk_0);
3193#else
3194 p_sorted_token_ids[e_start + position_0] = i_token;
3195#endif
3196 p_sorted_weights[e_start + position_0] =
3197 p_weights[i_token * kargs.topk_mdiv.divisor + i_topk_0];
3198 }
3199
3200 int position_1 = cumsum - i_show_1;
3201
3202 if(i_show_1)
3203 {
3204#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3205 p_sorted_token_ids[e_start + position_1] =
3206 MOE_SORTING_MOCK_ID(i_token + 1, i_topk_1);
3207#else
3208 p_sorted_token_ids[e_start + position_1] = i_token + 1;
3209#endif
3210 p_sorted_weights[e_start + position_1] =
3211 p_weights[(i_token + 1) * kargs.topk_mdiv.divisor + i_topk_1];
3212 }
3213#endif
3214 }
3215 }
3216 }
3217
3218 for(index_t i = e_start + prev_cumsum + threadIdx.x; i < e_end; i += kBlockSize)
3219 {
3220#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
3221 p_sorted_token_ids[i] = MOE_SORTING_MOCK_ID(tokens, kargs.topk_mdiv.divisor);
3222#else
3223 p_sorted_token_ids[i] = tokens;
3224#endif
3225 p_sorted_weights[i] = static_cast<WeightType>(0.0);
3226 }
3227 }
3228 }
3229};
3230
3231#undef MOE_SORTING_MOCK_ID
3232
3233} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
#define HIP_CHECK_ERROR(retval_or_funcall)
Definition host_utility/hip_check_error.hpp:21
__host__ __device__ constexpr T min(T x)
Definition utility/math.hpp:116
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_cumsum_smem_size(index_t num_experts)
Definition moe_sorting_kernel.hpp:1191
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t *buf, long_index_t buf_bytes, index_t gid)
Definition moe_sorting_kernel.hpp:1348
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_stride(index_t tokens)
Definition moe_sorting_kernel.hpp:1157
CK_TILE_DEVICE constexpr T moe_sorting_wave_reduce(T local, F reduce_f, number< wave_size_ >={})
Definition moe_sorting_kernel.hpp:1205
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(void *buf, index_t row, index_t col, index_t elem_bytes, index_t gid, index_t blocks)
Definition moe_sorting_kernel.hpp:1359
CK_TILE_DEVICE void moe_sorting_wave_cumsum(data_t &thread_data)
Definition moe_sorting_kernel.hpp:1236
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_sem_smem_size()
Definition moe_sorting_kernel.hpp:1198
CK_TILE_HOST constexpr auto moe_sorting_get_smem_size_p23(int num_experts_)
Definition moe_sorting_kernel.hpp:2677
CK_TILE_HOST index_t moe_sorting_mesh_byte_size(index_t tokens_, index_t, index_t topk_)
Definition moe_sorting_kernel.hpp:1166
CK_TILE_HOST_DEVICE index_t moe_sorting_mp_mesh_smem_size(index_t tokens, index_t num_experts, index_t topk)
Definition moe_sorting_kernel.hpp:1182
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
Definition arch.hpp:328
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
uint8_t uint8x16_t
Definition vector_type.hpp:202
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
CK_TILE_HOST index_t moe_sorting_get_sub_token(int tokens_, int num_experts_)
Definition moe_sorting_kernel.hpp:180
CK_TILE_DEVICE void s_waitcnt()
Definition arch.hpp:241
CK_TILE_HOST index_t moe_sorting_get_workspace_size(int tokens_, int num_experts_, int topk_, int dispatch_policy_)
Definition moe_sorting_kernel.hpp:1409
typename impl::ext_vector< T, N >::type ext_vector_t
Definition vector_type.hpp:84
CK_TILE_HOST index_t moe_sorting_mp_get_workspace_size(int tokens_, int num_experts_, int topk_)
Definition moe_sorting_kernel.hpp:1394
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition arch.hpp:121
CK_TILE_HOST bool moe_sorting_is_oneshot(int tokens_, int num_experts_)
Definition moe_sorting_kernel.hpp:1380
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST constexpr auto moe_sorting_get_smem_row_col(int tokens_, int num_experts_)
Definition moe_sorting_kernel.hpp:133
int32_t index_t
Definition ck.hpp:299
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
#define MOE_SORTING_MOCK_ID(token_id_, topk_id_)
Definition reference_moe_sorting.hpp:11
unsigned int uint32_t
Definition stdint.h:126
signed int int32_t
Definition stdint.h:123
Definition moe_sorting_kernel.hpp:1450
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1456
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1452
index_t tokens
Definition moe_sorting_kernel.hpp:1453
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1451
index_t mesh_byte_size
Definition moe_sorting_kernel.hpp:1457
index_t num_experts
Definition moe_sorting_kernel.hpp:1455
Definition moe_sorting_kernel.hpp:1442
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:1460
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1447
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1444
static CK_TILE_HOST constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1489
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:1445
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1486
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1491
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1443
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1472
static CK_TILE_HOST constexpr auto GridSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1484
Definition moe_sorting_kernel.hpp:189
void * p_ws
Definition moe_sorting_kernel.hpp:203
index_t tokens
Definition moe_sorting_kernel.hpp:206
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:219
void * p_moe_buf
Definition moe_sorting_kernel.hpp:202
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:199
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:190
const void * p_weights
Definition moe_sorting_kernel.hpp:191
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:196
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:197
index_t unit_size
Definition moe_sorting_kernel.hpp:207
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:193
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:198
index_t topk
Definition moe_sorting_kernel.hpp:209
index_t num_experts
Definition moe_sorting_kernel.hpp:208
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:194
Definition moe_sorting_kernel.hpp:240
void * p_moe_buf
Definition moe_sorting_kernel.hpp:249
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:248
mdiv unit_size_mdiv
Definition moe_sorting_kernel.hpp:260
index_t smem_rows
Definition moe_sorting_kernel.hpp:259
index_t tokens
Definition moe_sorting_kernel.hpp:250
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:246
index_t num_experts
Definition moe_sorting_kernel.hpp:251
mdiv expert_mdiv
Definition moe_sorting_kernel.hpp:262
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:256
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:245
index_t tokens_per_thread
Definition moe_sorting_kernel.hpp:258
const void * p_weights
Definition moe_sorting_kernel.hpp:242
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:243
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:244
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:261
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:247
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:241
Definition moe_sorting_kernel.hpp:738
CK_TILE_DEVICE simple_smem_indexer(index_t *smem_, index_t row_stride_)
Definition moe_sorting_kernel.hpp:743
CK_TILE_DEVICE index_t & operator()(index_t i_row, index_t i_col)
Definition moe_sorting_kernel.hpp:751
index_t * smem
Definition moe_sorting_kernel.hpp:739
CK_TILE_DEVICE simple_smem_indexer(index_t *smem_)
Definition moe_sorting_kernel.hpp:757
CK_TILE_DEVICE index_t & operator()(index_t idx)
Definition moe_sorting_kernel.hpp:759
index_t row_stride
Definition moe_sorting_kernel.hpp:740
CK_TILE_DEVICE const index_t & operator()(index_t idx) const
Definition moe_sorting_kernel.hpp:758
CK_TILE_DEVICE const index_t & operator()(index_t i_row, index_t i_col) const
Definition moe_sorting_kernel.hpp:747
Definition moe_sorting_kernel.hpp:226
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:229
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:236
CK_TILE_DEVICE void moe_buf_set_zero_kernel_2d(void *buf, index_t row, index_t col, index_t elem_bytes) const
Definition moe_sorting_kernel.hpp:500
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:230
CK_TILE_DEVICE void moe_align_block_size_kernel(const IndexType *__restrict__ topk_id, const WeightType *__restrict__ weights, index_t *p_sorted_token_ids, WeightType *p_sorted_weights, index_t *p_sorted_expert_ids, index_t *p_total_tokens_post_pad, const index_t num_experts, const index_t tokens_per_thread, const index_t numel, const mdiv unit_size_mdiv, const mdiv topk_mdiv, void *smem) const
Definition moe_sorting_kernel.hpp:517
CK_TILE_DEVICE void moe_align_block_size_kernel_ex(const IndexType *__restrict__ topk_id, const WeightType *__restrict__ weights, const IndexType *__restrict__ local_expert_mask, index_t *p_sorted_token_ids, WeightType *p_sorted_weights, index_t *p_sorted_expert_ids, index_t *p_total_tokens_post_pad, const index_t num_experts, const index_t tokens, const mdiv unit_size_mdiv, const mdiv topk_mdiv, const mdiv expert_mdiv, const index_t smem_rows, void *smem) const
Definition moe_sorting_kernel.hpp:763
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:278
__device__ void wave_cumsum(data_t &thread_data) const
Definition moe_sorting_kernel.hpp:354
CK_TILE_DEVICE void moe_buf_set_zero_kernel(uint8x16_t *buf, long_index_t buf_bytes) const
Definition moe_sorting_kernel.hpp:490
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:266
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:232
static __device__ constexpr T wave_reduce(T local, F reduce_f, number< wave_size_ >={})
Definition moe_sorting_kernel.hpp:457
CK_TILE_DEVICE index_t calc_index(index_t total_col, index_t row, index_t col) const
Definition moe_sorting_kernel.hpp:485
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:227
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:237
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:312
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:289
static CK_TILE_HOST constexpr auto GetSmemSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:300
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1090
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:234
Definition moe_sorting_kernel.hpp:1593
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1596
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1595
index_t tokens
Definition moe_sorting_kernel.hpp:1597
index_t num_experts
Definition moe_sorting_kernel.hpp:1599
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1600
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:1594
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:1601
Definition moe_sorting_kernel.hpp:1578
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:1583
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:1586
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1631
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:1582
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1636
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1616
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1585
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:1581
static CK_TILE_HOST constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1634
static CK_TILE_HOST constexpr auto GridSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1629
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1590
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1579
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:1588
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:1604
Definition moe_sorting_kernel.hpp:1714
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:1723
index_t num_experts
Definition moe_sorting_kernel.hpp:1725
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1720
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1717
const void * p_topk_ids
Definition moe_sorting_kernel.hpp:1715
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:1724
index_t tokens
Definition moe_sorting_kernel.hpp:1718
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:1721
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1716
Definition moe_sorting_kernel.hpp:1700
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:1703
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:1704
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1768
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:1728
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1763
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1701
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:1705
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1711
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1707
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1759
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:1757
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:1709
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1740
Definition moe_sorting_kernel.hpp:1907
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:1911
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:1909
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:1910
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:1908
index_t mesh_stride
Definition moe_sorting_kernel.hpp:1912
Definition moe_sorting_kernel.hpp:1893
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:1897
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:1898
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:1894
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:1905
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:1939
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:1901
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:1896
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:1934
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:1929
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:1900
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:1915
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:1931
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:1903
Definition moe_sorting_kernel.hpp:2702
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:2707
const void * p_weights
Definition moe_sorting_kernel.hpp:2703
index_t num_experts
Definition moe_sorting_kernel.hpp:2716
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:2729
index_t tokens
Definition moe_sorting_kernel.hpp:2715
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:2708
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:2705
index_t mesh_stride
Definition moe_sorting_kernel.hpp:2717
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:2709
void * p_moe_buf
Definition moe_sorting_kernel.hpp:2713
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:2706
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:2711
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:2704
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:2712
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:2719
mdiv unit_size_mdiv
Definition moe_sorting_kernel.hpp:2718
Definition moe_sorting_kernel.hpp:2688
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:2692
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:2801
static CK_TILE_HOST constexpr auto GetSmemSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2793
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:2691
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2779
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:2689
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:2700
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:2733
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:2696
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:2790
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:2693
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:2698
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:2695
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:2767
Definition moe_sorting_kernel.hpp:2284
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:2285
long_index_t moe_buf_bytes
Definition moe_sorting_kernel.hpp:2296
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:2287
index_t num_experts
Definition moe_sorting_kernel.hpp:2293
index_t tokens
Definition moe_sorting_kernel.hpp:2292
index_t mesh_stride
Definition moe_sorting_kernel.hpp:2294
void * p_total_tokens_post_pad
Definition moe_sorting_kernel.hpp:2289
void * p_moe_buf
Definition moe_sorting_kernel.hpp:2291
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:2288
void * p_sorted_expert_ids
Definition moe_sorting_kernel.hpp:2290
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:2286
mdiv unit_size_mdiv
Definition moe_sorting_kernel.hpp:2295
Definition moe_sorting_kernel.hpp:2270
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:2273
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2339
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:2271
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:2278
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:2349
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:2359
static CK_TILE_HOST constexpr auto get_num_cu()
Definition moe_sorting_kernel.hpp:2327
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:2277
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:2299
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:2280
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:2275
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:2274
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:2352
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:2282
Definition moe_sorting_kernel.hpp:2527
void * p_sorted_weights
Definition moe_sorting_kernel.hpp:2532
index_t mesh_stride
Definition moe_sorting_kernel.hpp:2538
const void * p_local_expert_mask
Definition moe_sorting_kernel.hpp:2529
void * p_expert_cumsum
Definition moe_sorting_kernel.hpp:2534
void * p_sorted_token_ids
Definition moe_sorting_kernel.hpp:2531
const void * p_local_tokens
Definition moe_sorting_kernel.hpp:2530
mdiv topk_mdiv
Definition moe_sorting_kernel.hpp:2539
const void * p_weights
Definition moe_sorting_kernel.hpp:2528
index_t tokens
Definition moe_sorting_kernel.hpp:2536
void * p_expert_mesh
Definition moe_sorting_kernel.hpp:2533
index_t num_experts
Definition moe_sorting_kernel.hpp:2537
Definition moe_sorting_kernel.hpp:2512
static constexpr index_t OCCUPANCY
Definition moe_sorting_kernel.hpp:2520
typename Problem::WeightType WeightType
Definition moe_sorting_kernel.hpp:2516
static CK_TILE_HOST constexpr auto GridSize(const Hargs &h)
Definition moe_sorting_kernel.hpp:2561
static constexpr index_t kBlockSize
Definition moe_sorting_kernel.hpp:2519
static CK_TILE_HOST constexpr auto BlockSize(const Hargs &)
Definition moe_sorting_kernel.hpp:2563
remove_cvref_t< Problem_ > Problem
Definition moe_sorting_kernel.hpp:2513
typename Problem::MeshType MeshType
Definition moe_sorting_kernel.hpp:2517
MoeSortingHostArgs Hargs
Definition moe_sorting_kernel.hpp:2524
static CK_TILE_HOST_DEVICE constexpr auto GetSmemSize()
Definition moe_sorting_kernel.hpp:2566
MoeSortingHostArgs MoeSortingKargs
Definition moe_sorting_kernel.hpp:2522
typename Problem::IndexType IndexType
Definition moe_sorting_kernel.hpp:2515
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition moe_sorting_kernel.hpp:2571
static CK_TILE_HOST constexpr auto MakeKargs(const Hargs &h)
Definition moe_sorting_kernel.hpp:2542
Definition magic_div.hpp:186
CK_TILE_HOST_DEVICE void divmod(uint32_t dividend_, uint32_t &quotient_, uint32_t &remainder_) const
Definition magic_div.hpp:218
uint32_t divisor
Definition magic_div.hpp:188
CK_TILE_HOST_DEVICE uint32_t div(uint32_t dividend_) const
Definition magic_div.hpp:212
Definition coordinate_transform.hpp:1392
Definition tile/core/utility/functional.hpp:43