fmha_batch_prefill_kernel.hpp Source File

fmha_batch_prefill_kernel.hpp Source File

Composable Kernel: fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11#include <string>
12#include <type_traits>
13#include <utility>
14#include <variant>
15
16// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21
22namespace ck_tile {
23
24template <typename FmhaPipeline_, typename EpiloguePipeline_>
26{
29 static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
30
31 static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
32 static_assert(kBlockPerCu > 0);
33 static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
34
44
46
47 static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
48 static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
49 static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
50 static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
51 static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
52 static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
53 static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
54 static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
55 static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
56 static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
59 static constexpr bool kHasMask = FmhaMask::IsMasking;
60
61 static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
62
63 // clang-format off
64 template <typename T> struct t2s; // maps an element type to the short dtype tag embedded in GetName(); the primary template is left undefined so unsupported types fail to compile
65 template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
66 template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
67 template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
68 template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
69 template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
70 // clang-format on
71
72 CK_TILE_HOST static std::string GetName()
73 {
74 // sync with generate.py
75 // clang-format off
76 using bfs = typename FmhaPipeline::BlockFmhaShape;
77 using g0br = typename bfs::Gemm0BlockWarps;
78 using g1br = typename bfs::Gemm1BlockWarps;
79 using g0wt = typename bfs::Gemm0WarpTile;
80 using g1wt = typename bfs::Gemm1WarpTile;
81 #define _SS_ std::string
82 #define _TS_ std::to_string
83 auto pn = [&] () {
84 std::string n;
85 if (kPadSeqLenQ) n += "s";
86 if (kPadSeqLenK) n += "sk";
87 if (kPadHeadDimQ) n += "d";
88 if (kPadHeadDimV) n += "dv";
89 return n.empty() ? n : std::string("p") + n; }();
90 return
91 _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
92 "_" + (kIsGroupMode ? "group" : "batch") + "_"
93 "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
94 _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
95 "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
96 "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
97 "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
98 "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
99 (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
100 "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
102 (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
103 #undef _SS_
104 #undef _TS_
105 // clang-format on
106 }
107
108 template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
109 // arg
111 {
112 };
113
114 // kargs use aggregate initialization, so no constructor is provided;
115 // inheritance is used to minimize the karg size.
116 // Users need to call MakeKargs() to create kargs.
118 {
119 const void* q_ptr;
120 const void* k_ptr;
121 const void* v_ptr;
122 void* o_ptr;
123
128
130 // For MQA/GQA, nhead_q and nhead_k can differ. This parameter is nhead_q / nhead_k;
131 // a value larger than 1 indicates the MQA/GQA case.
133
137#if 0 // we assume page_block_size=1 for now
138 const int32_t* kv_last_page_lens;
140#else
141 static constexpr ck_tile::index_t page_block_size = 1;
142#endif
143
144 float scale_s;
145
150
155 };
156
158 {
160
161 void init_logits_soft_cap(float logits_soft_cap_)
162 {
163 if(0 < logits_soft_cap_)
164 {
165 logits_soft_cap = logits_soft_cap_;
167 }
168 else
169 {
170 logits_soft_cap = 0.f;
172 }
173 }
174
177 };
178
185
190
192 {
193 // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194 const void* alibi_slope_ptr;
195 ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196 };
197
199 {
200 // ck_tile::index_t window_size_left, window_size_right;
203 };
204
206 {
207 float scale_p;
208 float scale_o;
209 };
210
217
231
233 {
234 void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
235 {
236 float p_undrop = 1.0 - p_drop;
238 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
239 rp_undrop = 1.0 / p_undrop;
240
241 this->drop_seed.val = seed;
242 this->drop_offset.val = offset;
244 }
245
246 void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
247 {
248 float p_undrop = 1.0 - p_drop;
250 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
251 rp_undrop = 1.0 / p_undrop;
252
253 this->drop_seed.ptr = seed_ptr;
254 this->drop_offset.ptr = offset_ptr;
255 this->is_drop_seed_offset_from_host = false;
256 }
257
258 float rp_undrop = 1;
259 uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
260 bool is_store_randval = false;
261 void* rand_val_ptr = nullptr;
262
265 };
266
271
274 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
275 FmhaFwdBatchModeBiasKargs,
276 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
277 FmhaFwdAlibiKargs,
278 FmhaFwdEmptyKargs<0>>>,
279 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
280 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
281 std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
282 std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
283 std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
284 {
289 };
290
293 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
294 FmhaFwdCommonBiasKargs,
295 std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
296 FmhaFwdAlibiKargs,
297 FmhaFwdEmptyKargs<0>>>,
298 std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
299 std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
300 std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
301 std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
302 std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
303 {
307 };
308
309 using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
310
317
318 template <bool Cond = !kIsGroupMode>
319 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
320 MakeKargs(const void* q_ptr,
321 const void* k_ptr,
322 const void* v_ptr,
323 const void* bias_ptr,
324 void* rand_val_ptr,
325 void* lse_ptr,
326 void* o_ptr,
327 ck_tile::index_t seqlen_q,
328 ck_tile::index_t hdim_q,
329 ck_tile::index_t hdim_v,
330 ck_tile::index_t num_head_q,
331 ck_tile::index_t nhead_ratio_qk,
332 int32_t num_total_pages,
333 const void* kv_indptr,
334 const void* kv_page_indices,
335#if 0 // we assume page_block_size=1 for now
336 const void* kv_last_page_lens,
337 ck_tile::index_t page_block_size,
338#endif
339 float scale_s,
340 float scale_p,
341 float scale_o,
342 float logits_soft_cap,
343 ck_tile::index_t stride_q,
344 ck_tile::index_t stride_k,
345 ck_tile::index_t stride_v,
346 ck_tile::index_t stride_bias,
347 ck_tile::index_t stride_randval,
348 ck_tile::index_t stride_o,
349 ck_tile::index_t nhead_stride_q,
350 ck_tile::index_t nhead_stride_k,
351 ck_tile::index_t nhead_stride_v,
352 ck_tile::index_t nhead_stride_bias,
353 ck_tile::index_t nhead_stride_randval,
354 ck_tile::index_t nhead_stride_lse,
355 ck_tile::index_t nhead_stride_o,
356 ck_tile::index_t batch_stride_q,
357 ck_tile::index_t batch_stride_k,
358 ck_tile::index_t batch_stride_v,
359 ck_tile::index_t batch_stride_bias,
360 ck_tile::index_t batch_stride_randval,
361 ck_tile::index_t batch_stride_lse,
362 ck_tile::index_t batch_stride_o,
363 ck_tile::index_t window_size_left,
364 ck_tile::index_t window_size_right,
365 ck_tile::index_t mask_type,
366 float p_drop,
367 bool s_randval,
368 std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
369 drop_seed_offset)
370 {
371 Kargs kargs{{q_ptr,
372 k_ptr,
373 v_ptr,
374 o_ptr,
375 seqlen_q,
376 -1,
377 hdim_q,
378 hdim_v,
379 num_head_q,
380 nhead_ratio_qk,
381 num_total_pages,
382 reinterpret_cast<const int32_t*>(kv_indptr),
383 reinterpret_cast<const int32_t*>(kv_page_indices),
384#if 0 // we assume page_block_size=1 for now
385 reinterpret_cast<const int32_t*>(kv_last_page_lens),
386 page_block_size,
387#endif
389 static_cast<float>(scale_s * ck_tile::log2e_v<>),
390#else
391 scale_s,
392#endif
393 stride_q,
394 stride_k,
395 stride_v,
396 stride_o,
397 nhead_stride_q,
398 nhead_stride_k,
399 nhead_stride_v,
400 nhead_stride_o}, // args for common karg
401 {}, // placeholder for bias
402 {}, // placeholder for mask
403 {}, // placeholder for lse
404 {}, // placeholder for fp8_static_quant args
405 {}, // placeholder for dropout
406 {}, // placeholder for logits_soft_cap
407 batch_stride_q,
408 batch_stride_k,
409 batch_stride_v,
410 batch_stride_o};
411
413 {
414 kargs.bias_ptr = bias_ptr;
415 kargs.stride_bias = stride_bias;
416 kargs.nhead_stride_bias = nhead_stride_bias;
417 kargs.batch_stride_bias = batch_stride_bias;
418 }
419 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
420 {
421 kargs.alibi_slope_ptr = bias_ptr;
422 kargs.alibi_slope_stride = stride_bias;
423 }
424 if constexpr(kHasMask)
425 {
426 kargs.window_size_left = window_size_left;
427 kargs.window_size_right = window_size_right;
428 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
429 }
430 if constexpr(kStoreLSE)
431 {
432 kargs.lse_ptr = lse_ptr;
433 kargs.nhead_stride_lse = nhead_stride_lse;
434 kargs.batch_stride_lse = batch_stride_lse;
435 }
436 if constexpr(kDoFp8StaticQuant)
437 {
438 kargs.scale_p = scale_p;
439 kargs.scale_o = scale_o;
440 }
441 if constexpr(kHasDropout)
442 {
443 if(drop_seed_offset.index() == 0) // seed & offset come from host
444 {
445 const auto& [seed, offset] = std::get<0>(drop_seed_offset);
446 kargs.init_dropout(p_drop, seed, offset);
447 }
448 else // seed & offset come from device
449 {
450 const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
451 kargs.init_dropout(p_drop,
452 reinterpret_cast<const uint64_t*>(seed_ptr),
453 reinterpret_cast<const uint64_t*>(offset_ptr));
454 }
455
456 kargs.rand_val_ptr = rand_val_ptr;
457 kargs.stride_randval = stride_randval;
458 kargs.nhead_stride_randval = nhead_stride_randval;
459 kargs.batch_stride_randval = batch_stride_randval;
460 kargs.is_store_randval = s_randval;
461 }
462 if constexpr(kHasLogitsSoftCap)
463 {
464 kargs.init_logits_soft_cap(logits_soft_cap);
465 }
466
467 return kargs;
468 }
469
470 template <bool Cond = kIsGroupMode>
471 CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
472 MakeKargs(const void* q_ptr,
473 const void* k_ptr,
474 const void* v_ptr,
475 const void* bias_ptr,
476 void* rand_val_ptr,
477 void* lse_ptr,
478 void* o_ptr,
479 const void* seqstart_q_ptr,
480 ck_tile::index_t hdim_q,
481 ck_tile::index_t hdim_v,
482 ck_tile::index_t num_head_q,
483 ck_tile::index_t nhead_ratio_qk,
484 int32_t num_total_pages,
485 const void* kv_indptr,
486 const void* kv_page_indices,
487#if 0 // we assume page_block_size=1 for now
488 const void* kv_last_page_lens,
489 ck_tile::index_t page_block_size,
490#endif
491 float scale_s,
492 float scale_p,
493 float scale_o,
494 float logits_soft_cap,
495 ck_tile::index_t stride_q,
496 ck_tile::index_t stride_k,
497 ck_tile::index_t stride_v,
498 ck_tile::index_t stride_bias,
499 ck_tile::index_t stride_randval,
500 ck_tile::index_t stride_o,
501 ck_tile::index_t nhead_stride_q,
502 ck_tile::index_t nhead_stride_k,
503 ck_tile::index_t nhead_stride_v,
504 ck_tile::index_t nhead_stride_bias,
505 ck_tile::index_t nhead_stride_randval,
506 ck_tile::index_t nhead_stride_lse,
507 ck_tile::index_t nhead_stride_o,
508 ck_tile::index_t batch_stride_k,
509 ck_tile::index_t batch_stride_v,
510 ck_tile::index_t window_size_left,
511 ck_tile::index_t window_size_right,
512 ck_tile::index_t mask_type,
513 float p_drop,
514 bool s_randval,
515 std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
516 drop_seed_offset)
517 {
518 Kargs kargs{{q_ptr,
519 k_ptr,
520 v_ptr,
521 o_ptr,
522 -1, // seqlen will be updated by another pointer
523 -1, //
524 hdim_q,
525 hdim_v,
526 num_head_q,
527 nhead_ratio_qk,
528 num_total_pages,
529 reinterpret_cast<const int32_t*>(kv_indptr),
530 reinterpret_cast<const int32_t*>(kv_page_indices),
531#if 0 // we assume page_block_size=1 for now
532 reinterpret_cast<const int32_t*>(kv_last_page_lens),
533 page_block_size,
534#endif
536 static_cast<float>(scale_s * ck_tile::log2e_v<>),
537#else
538 scale_s,
539#endif
540 stride_q,
541 stride_k,
542 stride_v,
543 stride_o,
544 nhead_stride_q,
545 nhead_stride_k,
546 nhead_stride_v,
547 nhead_stride_o}, // args for common karg
548 {}, // placeholder for bias
549 {}, // placeholder for mask
550 {}, // placeholder for lse
551 {}, // placeholder for fp8_static_quant args
552 {}, // placeholder for dropout
553 {}, // placeholder for logits_soft_cap
554 reinterpret_cast<const int32_t*>(seqstart_q_ptr),
555 batch_stride_k,
556 batch_stride_v};
557
559 {
560 kargs.bias_ptr = bias_ptr;
561 kargs.stride_bias = stride_bias;
562 kargs.nhead_stride_bias = nhead_stride_bias;
563 }
564 else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
565 {
566 kargs.alibi_slope_ptr = bias_ptr;
567 kargs.alibi_slope_stride = stride_bias;
568 }
569 if constexpr(kHasMask)
570 {
571 kargs.window_size_left = window_size_left;
572 kargs.window_size_right = window_size_right;
573 kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
574 }
575 if constexpr(kStoreLSE)
576 {
577 kargs.lse_ptr = lse_ptr;
578 kargs.nhead_stride_lse = nhead_stride_lse;
579 }
580 if constexpr(kDoFp8StaticQuant)
581 {
582 kargs.scale_p = scale_p;
583 kargs.scale_o = scale_o;
584 }
585 if constexpr(kHasDropout)
586 {
587 if(drop_seed_offset.index() == 0) // seed & offset come from host
588 {
589 const auto& [seed, offset] = std::get<0>(drop_seed_offset);
590 kargs.init_dropout(p_drop, seed, offset);
591 }
592 else // seed & offset come from device
593 {
594 const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
595 kargs.init_dropout(p_drop,
596 reinterpret_cast<const uint64_t*>(seed_ptr),
597 reinterpret_cast<const uint64_t*>(offset_ptr));
598 }
599
600 kargs.rand_val_ptr = rand_val_ptr;
601 kargs.stride_randval = stride_randval;
602 kargs.nhead_stride_randval = nhead_stride_randval;
603 kargs.is_store_randval = s_randval;
604 }
605 if constexpr(kHasLogitsSoftCap)
606 {
607 kargs.init_logits_soft_cap(logits_soft_cap);
608 }
609
610 return kargs;
611 }
612
613 CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
614 ck_tile::index_t nhead_,
615 ck_tile::index_t seqlen_q_,
616 ck_tile::index_t hdim_v_)
617 {
618 if constexpr(kIsGroupMode)
619 {
620 // TODO: this may need tuning
621 return dim3(nhead_,
622 batch_size_,
623 ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
624 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
625 }
626 else
627 {
628 // TODO: this may need tuning
629 return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
630 ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
631 nhead_,
632 batch_size_);
633 }
634 }
635
636 CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
637 {
638 if constexpr(kIsGroupMode)
639 {
640 // const index_t num_tile_m0 = seqlen_q / kM0;
641 const index_t num_tile_n1 =
642 ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
643
644 const index_t i_block = blockIdx.z;
645 const index_t i_nhead = blockIdx.x;
646 const index_t i_batch = blockIdx.y;
647
648 const auto f = [](index_t dividend, index_t divisor) {
649 index_t quotient = dividend / divisor;
650 index_t modulus = dividend - quotient * divisor;
651 return ck_tile::make_tuple(quotient, modulus);
652 };
653
654 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
655 if constexpr(kHasMask)
656 {
657 // assume that num_tile_n1 is always 1
658 return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
659 }
660 else
661 {
662 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
663 }
664 }
665 else
666 {
667 // const index_t num_tile_m0 = seqlen_q / kM0;
668 const index_t num_tile_n1 =
669 ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
670
671 const index_t i_block = blockIdx.x;
672 const index_t i_nhead = blockIdx.y;
673 const index_t i_batch = blockIdx.z;
674
675 const auto f = [](index_t dividend, index_t divisor) {
676 index_t quotient = dividend / divisor;
677 index_t modulus = dividend - quotient * divisor;
678 return ck_tile::make_tuple(quotient, modulus);
679 };
680
681 const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
682
683 if constexpr(kHasMask)
684 {
685 // assume that num_tile_n1 is always 1
686 return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
687 }
688 else
689 {
690 return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
691 }
692 }
693 }
694
696 {
697 if(is_wave32())
698 {
699 return dim3(kBlockSize / 2);
700 }
701 else
702 {
703 return dim3(kBlockSize);
704 }
705 }
706
708 {
709 return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
710 }
711
713 {
714 // allocate LDS
715 __shared__ char smem_ptr[GetSmemSize()];
716
717 // divide problem
718 const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
719
720 const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
721 const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
722
723 long_index_t batch_offset_q = 0;
724 long_index_t batch_offset_bias = 0;
725 long_index_t batch_offset_randval = 0;
726 long_index_t batch_offset_lse = 0;
727 long_index_t batch_offset_o = 0;
728
729 const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
730#if 0 // we assume page_block_size=1 for now
731 const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
732#endif
733 if constexpr(kIsGroupMode)
734 {
735 // get starting offset for each batch
736 const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
737
738 batch_offset_q = query_start * kargs.stride_q;
739
740 kargs.kv_page_indices += kargs.kv_indptr[i_batch];
741
743 {
744 batch_offset_bias = query_start * kargs.stride_bias;
745 }
746 if constexpr(kStoreLSE)
747 {
748 batch_offset_lse = query_start;
749 }
750 if constexpr(kHasDropout)
751 {
752 batch_offset_randval = query_start * kargs.stride_randval;
753 }
754 batch_offset_o = query_start * kargs.stride_o;
755
756 // get real # queries & # keys under group mode
757 kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
758
759 // # of required blocks is different in each groups, terminate unnecessary blocks
760 // earlier
761 if(kargs.seqlen_q <= i_m0)
762 {
763 return;
764 }
765
766#if 0 // we assume page_block_size=1 for now
767 kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
768#else
769 kargs.seqlen_k = num_page_blocks;
770#endif
771 }
772 else
773 {
774 batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
775
776 kargs.kv_page_indices += kargs.kv_indptr[i_batch];
777
779 {
780 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
781 }
782 if constexpr(kStoreLSE)
783 {
784 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
785 }
786 if constexpr(kHasDropout)
787 {
788 batch_offset_randval =
789 static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
790 }
791 batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
792
793#if 0 // we assume page_block_size=1 for now
794 kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
795#else
796 kargs.seqlen_k = num_page_blocks;
797#endif
798 }
799
800 // for simplicity, batch stride we just modify the pointer
801 const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
802 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
803 batch_offset_q;
804 const KDataType* k_ptr =
805 reinterpret_cast<const KDataType*>(kargs.k_ptr) +
806 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
807 const VDataType* v_ptr =
808 reinterpret_cast<const VDataType*>(kargs.v_ptr) +
809 static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
810 ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
811 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
812 batch_offset_o;
813
814 // Q/K/V DRAM and DRAM window
815 const auto q_dram = [&]() {
817 q_ptr,
818 make_tuple(kargs.seqlen_q, kargs.hdim_q),
819 make_tuple(kargs.stride_q, 1),
821 number<1>{});
822 if constexpr(FmhaPipeline::kQLoadOnce)
823 {
824 return pad_tensor_view(
825 q_dram_naive,
828 }
829 else
830 {
831 return pad_tensor_view(
832 q_dram_naive,
835 }
836 }();
837 const auto k_dram = [&]() {
839 k_ptr,
840 make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
841 make_tuple(kargs.stride_k, 1),
843 number<1>{});
844
845 constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
846 return pad_tensor_view(
847 k_dram_naive,
850 }();
851 const auto v_dram = [&]() {
852 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
853 {
855 v_ptr,
856 make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
857 make_tuple(kargs.stride_v, 1),
859 number<1>{});
860
861 const auto v_dram_transposed = transform_tensor_view(
862 v_dram_naive,
864 make_pass_through_transform(kargs.hdim_v),
865 make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
868
869 constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
870 return pad_tensor_view(
871 v_dram_transposed,
874 }
875 else
876 {
878 v_ptr,
879 make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
880 make_tuple(kargs.stride_v, 1),
882 number<1>{});
883
884 constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
885 return pad_tensor_view(
886 v_dram_naive,
889 }
890 }();
891
892 auto q_dram_window = make_tile_window(
893 q_dram,
894 [&]() {
895 if constexpr(FmhaPipeline::kQLoadOnce)
898 else
900 }(),
901 {i_m0, 0});
902
903 auto k_dram_window = make_tile_window(
905
906 auto v_dram_window =
907 make_tile_window(v_dram,
909 {i_n1, 0});
912 const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
913 constexpr auto bias_dram_window_lengths =
916 {
917 const BiasDataType* bias_ptr =
918 reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
919 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
920 batch_offset_bias;
921
922 const auto bias_dram = [&]() {
923 const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
924 bias_ptr,
925 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
926 make_tuple(kargs.stride_bias, 1),
928 number<1>{});
929
930 return pad_tensor_view(bias_dram_naive,
931 bias_dram_window_lengths,
933 }();
934
935 return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
936 }
937 else
938 {
939 return make_null_tile_window(bias_dram_window_lengths);
940 }
941 }();
942
943 // lse
944 auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
945 constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
946 if constexpr(kStoreLSE)
947 {
948 LSEDataType* lse_ptr =
949 reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
950 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
951
952 const auto lse_dram = [&]() {
953 const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
954 lse_ptr,
955 make_tuple(kargs.seqlen_q),
956 make_tuple(1),
957 number<1>{},
958 number<1>{});
959
960 return pad_tensor_view(
961 lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
962 }();
963
964 return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
965 }
966 else
967 {
968 return make_null_tile_window(lse_dram_window_lengths);
969 }
970 }();
971
972 auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
973 if constexpr(kHasDropout)
974 {
975 return BlockDropout{i_batch_,
976 i_nhead_,
977 kargs.num_head_q,
978 kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
979 : *kargs.drop_seed.ptr,
980 kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
981 : *kargs.drop_offset.ptr,
982 kargs.rp_undrop,
983 kargs.p_undrop_in_uint8_t,
984 kargs.is_store_randval};
985 }
986 else
987 {
988 return NullBlockDropout{};
989 };
990 }();
991
992 auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
993 constexpr auto randval_dram_window_lengths =
995 if constexpr(kHasDropout)
996 {
997 RandValOutputDataType* rand_val_ptr =
998 reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
999 static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1000 batch_offset_randval;
1001
1002 const auto randval_dram = [&]() {
1003 const auto randval_dram_naive =
1005 rand_val_ptr,
1006 make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1007 make_tuple(kargs.stride_randval, 1),
1008 number<1>{},
1009 number<1>{});
1010
1011 return pad_tensor_view(randval_dram_naive,
1012 randval_dram_window_lengths,
1014 }();
1015
1016 return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1017 }
1018 else
1019 {
1020 return make_null_tile_window(randval_dram_window_lengths);
1021 }
1022 }();
1023
1024 FmhaMask mask = [&]() {
1025 if constexpr(kHasMask)
1027 kargs.window_size_left,
1028 kargs.window_size_right,
1029 kargs.seqlen_q,
1030 kargs.seqlen_k,
1032 else
1033 return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1034 }();
1035
1036 // WA: structured bindings (i_batch, i_nhead) cannot be captured by lambdas before C++20, so copy them into init-captures
1037 auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1039 {
1040 // data loading, shared by entire wg
1041 // TODO: how to use s_read?
1042 SaccDataType slope =
1043 *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1044 i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1045#if CK_TILE_FMHA_FWD_FAST_EXP2
1046 slope *= ck_tile::log2e_v<>;
1047#endif
1048 if constexpr(kHasMask)
1049 {
1051 kargs.window_size_left,
1052 kargs.window_size_right,
1053 kargs.seqlen_q,
1054 kargs.seqlen_k,
1055 kargs.mask_type);
1056 }
1057 else
1058 {
1060 slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1061 }
1062 }
1063 else
1064 {
1066 }
1067 }();
1068
1069 AttentionVariant variant;
1070 const auto variant_params = [&] {
1071 if constexpr(kHasLogitsSoftCap)
1072 {
1074 mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1075 }
1076 else
1077 {
1078 return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1079 }
1080 }();
1081
1082 BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1083
1084 auto o_acc_tile = [&]() {
1085 if constexpr(kDoFp8StaticQuant)
1086 {
1087 return FmhaPipeline{}(
1088 q_dram_window,
1089 identity{}, // q_element_func
1090 k_dram_window,
1091 identity{}, // k_element_func
1092 v_dram_window,
1093 identity{}, // v_element_func
1094 bias_dram_window,
1095 identity{}, // bias_element_func
1096 randval_dram_window,
1097 lse_dram_window,
1098 identity{}, // lse_element_func
1099 identity{}, // s_acc_element_func
1100 scales{kargs.scale_p}, // p_compute_element_func
1101 composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1102 mask,
1103 position_encoding,
1104 kargs.scale_s,
1105 variant,
1106 variant_params,
1107 block_indices,
1108 smem_ptr,
1109 kargs.kv_page_indices,
1110 kargs.stride_k,
1111 kargs.stride_v,
1112 dropout);
1113 }
1114 else
1115 {
1116 return FmhaPipeline{}(q_dram_window,
1117 k_dram_window,
1118 v_dram_window,
1119 bias_dram_window,
1120 randval_dram_window,
1121 lse_dram_window,
1122 mask,
1123 position_encoding,
1124 kargs.scale_s,
1125 variant,
1126 variant_params,
1127 block_indices,
1128 smem_ptr,
1129 kargs.kv_page_indices,
1130 kargs.stride_k,
1131 kargs.stride_v,
1132 dropout);
1133 }
1134 }();
1135
1136 // O DRAM and O DRAM window
1137 auto o_dram = [&]() {
1139 o_ptr,
1140 make_tuple(kargs.seqlen_q, kargs.hdim_v),
1141 make_tuple(kargs.stride_o, 1),
1143 number<1>{});
1144
1145 return pad_tensor_view(
1146 o_dram_naive,
1149 }();
1150
1151 auto o_dram_window =
1152 make_tile_window(o_dram,
1154 {i_m0, i_n1});
1155
1156 EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1157 }
1158};
1159
1160} // namespace ck_tile
#define _TS_
#define _SS_
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tensor_view.hpp:471
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_generic_attention_mask_from_lr_window(index_t left_size, index_t right_size, index_t y_total, index_t x_total, bool is_top_left=true)
Definition block_masking.hpp:632
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
_BitInt(8) fp8_t
Definition float8.hpp:204
constexpr T log2e_v
Definition tile/core/numeric/math.hpp:488
int64_t long_index_t
Definition integer.hpp:11
CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope, index_t window_left_size, index_t window_right_size, index_t y_total, index_t x_total, GenericAttentionMaskEnum mask_enum)
Definition block_position_encoding.hpp:148
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition tensor_view.hpp:530
int32_t int32_t
Definition integer.hpp:10
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
unsigned _BitInt(8) bf8_t
Definition float8.hpp:206
GenericAttentionMaskEnum
Definition block_masking.hpp:11
@ MASK_FROM_TOP_LEFT
Definition block_masking.hpp:15
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
int32_t index_t
Definition integer.hpp:9
@ FROM_BOTTOM_RIGHT
Definition block_position_encoding.hpp:43
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_view.hpp:511
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
CK_TILE_HOST bool is_wave32()
Definition arch.hpp:72
unsigned char uint8_t
Definition stdint.h:124
unsigned __int64 uint64_t
Definition stdint.h:136
Definition block_position_encoding.hpp:48
Definition block_attention_bias_enum.hpp:19
Definition block_dropout.hpp:53
Definition block_position_encoding.hpp:137
Definition fmha_batch_prefill_kernel.hpp:312
ck_tile::index_t kv_head_idx
Definition fmha_batch_prefill_kernel.hpp:315
ck_tile::index_t qo_head_idx
Definition fmha_batch_prefill_kernel.hpp:314
ck_tile::index_t batch_idx
Definition fmha_batch_prefill_kernel.hpp:313
Definition fmha_batch_prefill_kernel.hpp:192
ck_tile::index_t alibi_slope_stride
Definition fmha_batch_prefill_kernel.hpp:195
const void * alibi_slope_ptr
Definition fmha_batch_prefill_kernel.hpp:194
ck_tile::index_t batch_stride_bias
Definition fmha_batch_prefill_kernel.hpp:188
ck_tile::index_t batch_stride_randval
Definition fmha_batch_prefill_kernel.hpp:269
ck_tile::index_t batch_stride_o
Definition fmha_batch_prefill_kernel.hpp:288
ck_tile::index_t batch_stride_v
Definition fmha_batch_prefill_kernel.hpp:287
ck_tile::index_t batch_stride_q
Definition fmha_batch_prefill_kernel.hpp:285
ck_tile::index_t batch_stride_k
Definition fmha_batch_prefill_kernel.hpp:286
ck_tile::index_t nhead_stride_bias
Definition fmha_batch_prefill_kernel.hpp:183
ck_tile::index_t stride_bias
Definition fmha_batch_prefill_kernel.hpp:182
const void * bias_ptr
Definition fmha_batch_prefill_kernel.hpp:181
ck_tile::index_t stride_randval
Definition fmha_batch_prefill_kernel.hpp:263
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition fmha_batch_prefill_kernel.hpp:246
ck_tile::index_t nhead_stride_randval
Definition fmha_batch_prefill_kernel.hpp:264
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition fmha_batch_prefill_kernel.hpp:234
void * rand_val_ptr
Definition fmha_batch_prefill_kernel.hpp:261
float rp_undrop
Definition fmha_batch_prefill_kernel.hpp:258
bool is_store_randval
Definition fmha_batch_prefill_kernel.hpp:260
uint8_t p_undrop_in_uint8_t
Definition fmha_batch_prefill_kernel.hpp:259
ck_tile::index_t stride_q
Definition fmha_batch_prefill_kernel.hpp:146
ck_tile::index_t stride_v
Definition fmha_batch_prefill_kernel.hpp:148
int32_t num_total_pages
Definition fmha_batch_prefill_kernel.hpp:134
float scale_s
Definition fmha_batch_prefill_kernel.hpp:144
ck_tile::index_t seqlen_q
Definition fmha_batch_prefill_kernel.hpp:124
ck_tile::index_t stride_k
Definition fmha_batch_prefill_kernel.hpp:147
ck_tile::index_t nhead_stride_o
Definition fmha_batch_prefill_kernel.hpp:154
ck_tile::index_t nhead_stride_k
Definition fmha_batch_prefill_kernel.hpp:152
ck_tile::index_t nhead_ratio_qk
Definition fmha_batch_prefill_kernel.hpp:132
ck_tile::index_t nhead_stride_v
Definition fmha_batch_prefill_kernel.hpp:153
ck_tile::index_t nhead_stride_q
Definition fmha_batch_prefill_kernel.hpp:151
const int32_t * kv_page_indices
Definition fmha_batch_prefill_kernel.hpp:136
const void * v_ptr
Definition fmha_batch_prefill_kernel.hpp:121
const int32_t * kv_indptr
Definition fmha_batch_prefill_kernel.hpp:135
void * o_ptr
Definition fmha_batch_prefill_kernel.hpp:122
ck_tile::index_t seqlen_k
Definition fmha_batch_prefill_kernel.hpp:125
ck_tile::index_t stride_o
Definition fmha_batch_prefill_kernel.hpp:149
ck_tile::index_t hdim_v
Definition fmha_batch_prefill_kernel.hpp:127
ck_tile::index_t num_head_q
Definition fmha_batch_prefill_kernel.hpp:129
static constexpr ck_tile::index_t page_block_size
Definition fmha_batch_prefill_kernel.hpp:141
const void * k_ptr
Definition fmha_batch_prefill_kernel.hpp:120
ck_tile::index_t hdim_q
Definition fmha_batch_prefill_kernel.hpp:126
const void * q_ptr
Definition fmha_batch_prefill_kernel.hpp:119
ck_tile::index_t batch_stride_lse
Definition fmha_batch_prefill_kernel.hpp:215
ck_tile::index_t nhead_stride_lse
Definition fmha_batch_prefill_kernel.hpp:214
void * lse_ptr
Definition fmha_batch_prefill_kernel.hpp:213
bool is_drop_seed_offset_from_host
Definition fmha_batch_prefill_kernel.hpp:229
ValueOrPointer< uint64_t > drop_seed
Definition fmha_batch_prefill_kernel.hpp:227
ValueOrPointer< uint64_t > drop_offset
Definition fmha_batch_prefill_kernel.hpp:228
Definition fmha_batch_prefill_kernel.hpp:111
float scale_p
Definition fmha_batch_prefill_kernel.hpp:207
float scale_o
Definition fmha_batch_prefill_kernel.hpp:208
ck_tile::index_t batch_stride_v
Definition fmha_batch_prefill_kernel.hpp:306
ck_tile::index_t batch_stride_k
Definition fmha_batch_prefill_kernel.hpp:305
const int32_t * seqstart_q_ptr
Definition fmha_batch_prefill_kernel.hpp:304
float logits_soft_cap_rcp
Definition fmha_batch_prefill_kernel.hpp:176
void init_logits_soft_cap(float logits_soft_cap_)
Definition fmha_batch_prefill_kernel.hpp:161
float logits_soft_cap
Definition fmha_batch_prefill_kernel.hpp:175
Definition fmha_batch_prefill_kernel.hpp:199
ck_tile::index_t window_size_right
Definition fmha_batch_prefill_kernel.hpp:201
ck_tile::index_t window_size_left
Definition fmha_batch_prefill_kernel.hpp:201
ck_tile::GenericAttentionMaskEnum mask_type
Definition fmha_batch_prefill_kernel.hpp:202
static constexpr const char * name
Definition fmha_batch_prefill_kernel.hpp:67
static constexpr const char * name
Definition fmha_batch_prefill_kernel.hpp:69
static constexpr const char * name
Definition fmha_batch_prefill_kernel.hpp:66
static constexpr const char * name
Definition fmha_batch_prefill_kernel.hpp:68
static constexpr const char * name
Definition fmha_batch_prefill_kernel.hpp:65
Definition fmha_batch_prefill_kernel.hpp:64
Definition fmha_batch_prefill_kernel.hpp:26
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition fmha_batch_prefill_kernel.hpp:707
static constexpr bool kIsGroupMode
Definition fmha_batch_prefill_kernel.hpp:47
static constexpr ck_tile::index_t kBlockPerCu
Definition fmha_batch_prefill_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition fmha_batch_prefill_kernel.hpp:37
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition fmha_batch_prefill_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition fmha_batch_prefill_kernel.hpp:36
static constexpr bool kPadSeqLenQ
Definition fmha_batch_prefill_kernel.hpp:48
static constexpr bool kDoFp8StaticQuant
Definition fmha_batch_prefill_kernel.hpp:56
static constexpr bool kPadHeadDimV
Definition fmha_batch_prefill_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition fmha_batch_prefill_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition fmha_batch_prefill_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition fmha_batch_prefill_kernel.hpp:35
static constexpr bool kHasMask
Definition fmha_batch_prefill_kernel.hpp:59
static CK_TILE_HOST std::string GetName()
Definition fmha_batch_prefill_kernel.hpp:72
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition fmha_batch_prefill_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition fmha_batch_prefill_kernel.hpp:58
static CK_TILE_HOST dim3 BlockSize()
Definition fmha_batch_prefill_kernel.hpp:695
static constexpr bool kPadSeqLenK
Definition fmha_batch_prefill_kernel.hpp:49
static constexpr bool kHasLogitsSoftCap
Definition fmha_batch_prefill_kernel.hpp:52
static constexpr bool kHasDropout
Definition fmha_batch_prefill_kernel.hpp:55
static constexpr bool kStoreLSE
Definition fmha_batch_prefill_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition fmha_batch_prefill_kernel.hpp:43
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset)
Definition fmha_batch_prefill_kernel.hpp:320
static CK_TILE_DEVICE constexpr auto GetTileIndex(const Kargs &kargs)
Definition fmha_batch_prefill_kernel.hpp:636
static constexpr ck_tile::index_t kBlockSize
Definition fmha_batch_prefill_kernel.hpp:29
static CK_TILE_HOST constexpr auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition fmha_batch_prefill_kernel.hpp:613
static constexpr auto BiasEnum
Definition fmha_batch_prefill_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition fmha_batch_prefill_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition fmha_batch_prefill_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition fmha_batch_prefill_kernel.hpp:57
static constexpr bool kUseAsyncCopy
Definition fmha_batch_prefill_kernel.hpp:61
static constexpr ck_tile::index_t kBlockPerCuInput
Definition fmha_batch_prefill_kernel.hpp:33
static constexpr bool kPadHeadDimQ
Definition fmha_batch_prefill_kernel.hpp:50
static CK_TILE_HOST constexpr std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * > > drop_seed_offset)
Definition fmha_batch_prefill_kernel.hpp:472
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition fmha_batch_prefill_kernel.hpp:712
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition fmha_batch_prefill_kernel.hpp:309
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition fmha_batch_prefill_kernel.hpp:28
Definition variants.hpp:63
Definition block_dropout.hpp:39
Definition variants.hpp:51
Definition tile/core/utility/functional.hpp:86
Definition coordinate_transform.hpp:1392
Definition unary_element_function.hpp:56
Definition tile/core/numeric/math.hpp:28
Definition tile/core/container/sequence.hpp:49