block_fmha_fwd_v3_pipeline.hpp Source File

block_fmha_fwd_v3_pipeline.hpp Source File

Composable Kernel: block_fmha_fwd_v3_pipeline.hpp Source File
block_fmha_fwd_v3_pipeline.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
// ---------------------------------------------------------------------------
// Debug / tuning switches for this pipeline.
//
// Every switch is guarded with #ifndef (the same pattern already used for
// CK_TILE_DISABLE_PACKED_FP32) so a build can override it with -D<NAME>=<value>
// or by defining it before including this header, instead of triggering a
// macro redefinition.
// NOTE(review): none of these macros are #undef'd by this header, so they stay
// visible to every translation unit that includes it.
// ---------------------------------------------------------------------------

// When non-zero, ASM_MARKER(m) emits the assembly comment "; [POYENC] m",
// fenced on both sides by __builtin_amdgcn_sched_barrier(0), so phase
// boundaries can be found in the ISA dump. The expansion is several statements
// (not wrapped in do/while), so only use it as a full statement in a braced scope.
#ifndef ENABLE_ASM_MARKER
#define ENABLE_ASM_MARKER 1
#endif
#if ENABLE_ASM_MARKER
#define ASM_MARKER(marker)                 \
    __builtin_amdgcn_sched_barrier(0);     \
    asm volatile("; [POYENC] " #marker);   \
    __builtin_amdgcn_sched_barrier(0);
#else
#define ASM_MARKER(marker)
#endif

// When non-zero, the core loop issues an extra s_barrier before phase 0.
#ifndef ADD_SBARRIER_FOR_PHASE0
#define ADD_SBARRIER_FOR_PHASE0 1
#endif

// When non-zero, some helpers (e.g. detail::fma_impl_vsv) use plain C++
// instead of their hand-written inline-assembly path.
#if !defined(CK_TILE_DISABLE_PACKED_FP32)
#define CK_TILE_DISABLE_PACKED_FP32 0
#endif

// Thread selected by DEBUG_STMTS: block 0, warp WARP_ID, lane LANE_ID.
#ifndef WARP_ID
#define WARP_ID 0
#endif
#ifndef LANE_ID
#define LANE_ID 0
#endif

// `DEBUG_STMTS <statement>` runs <statement> only on the selected thread.
// It expands to "if(!(cond)) {} else" rather than a bare "if(cond)" so that an
// `else` written after a DEBUG_STMTS statement binds to the caller's own `if`
// and not to the macro's hidden one (dangling-else). When disabled, the
// statement sits in the discarded branch of an `if constexpr`.
#ifndef ENABLE_DEBUG_STMTS
#define ENABLE_DEBUG_STMTS 1
#endif
#if ENABLE_DEBUG_STMTS
#define DEBUG_STMTS                                                                        \
    if(!(get_block_1d_id() == 0 && get_warp_id() == WARP_ID && get_lane_id() == LANE_ID)) \
    {                                                                                      \
    }                                                                                      \
    else
#else
#define DEBUG_STMTS \
    if constexpr(true) \
    {                  \
    }                  \
    else
#endif
35
36namespace ck_tile {
37
38template <typename PipelineProblem, bool kIsMasking>
40
41template <typename PipelineProblem>
42struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
43{
44 template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
47 {
48 using namespace ck_tile;
49
50 if constexpr(WaveGroup == 0)
51 {
52 if constexpr(Phase == 0)
53 {
54 static_for<0, 8, 1>{}([&](auto) {
55 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
56 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
57 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
58 });
59 }
60 else if constexpr(Phase == 1)
61 {
62 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
63 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
64 }
65 else if constexpr(Phase == 2)
66 {
67#if !CK_TILE_DISABLE_PACKED_FP32
68 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
69#endif
70 static_for<0, 8, 1>{}([&](auto) {
71 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
72 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
73 });
74 }
75 else if constexpr(Phase == 3)
76 {
77 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
78 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
79 }
80 }
81 else
82 {
83 if constexpr(Phase == 0)
84 {
85 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
86 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
87 }
88 else if constexpr(Phase == 1)
89 {
90 static_for<0, 8, 1>{}([&](auto) {
91 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
92 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
93 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
94 });
95 }
96 else if constexpr(Phase == 2)
97 {
98 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
99 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
100 }
101 else if constexpr(Phase == 3)
102 {
103#if !CK_TILE_DISABLE_PACKED_FP32
104 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
105#endif
106 static_for<0, 8, 1>{}([&](auto) {
107 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
108 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
109 });
110 }
111 }
112 }
113};
114
115template <typename PipelineProblem>
116struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
117{
118 template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
121 {
122 using namespace ck_tile;
123
124 if constexpr(WaveGroup == 0)
125 {
126 if constexpr(Phase == 0)
127 {
128 static_for<0, 8, 1>{}([&](auto) {
129 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
130 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
131 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
132 });
133 }
134 else if constexpr(Phase == 1)
135 {
136 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
137 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
138 }
139 else if constexpr(Phase == 2)
140 {
141#if !CK_TILE_DISABLE_PACKED_FP32
142 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
143#endif
144 static_for<0, 8, 1>{}([&](auto) {
145 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
146 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
147 });
148 }
149 else if constexpr(Phase == 3)
150 {
151 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
152 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
153 }
154 }
155 else
156 {
157 if constexpr(Phase == 0)
158 {
159 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
160 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
161 }
162 else if constexpr(Phase == 1)
163 {
164 static_for<0, 8, 1>{}([&](auto) {
165 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
166 __builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
167 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
168 });
169 }
170 else if constexpr(Phase == 2)
171 {
172 __builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
173 __builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
174 }
175 else if constexpr(Phase == 3)
176 {
177#if !CK_TILE_DISABLE_PACKED_FP32
178 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
179#endif
180 static_for<0, 8, 1>{}([&](auto) {
181 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
182 __builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
183 });
184 }
185 }
186 }
187};
188
189namespace detail {
190CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
191{
192#if CK_TILE_DISABLE_PACKED_FP32
193 return a * b + c;
194#else
195 float result;
196 asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
197 : [result] "=v"(result)
198 : [a] "v"(a), [b] "s"(b), [c] "v"(c));
199 return result;
200#endif
201}
202
203CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
204{
205 float result;
206 asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
207 : [result] "=v"(result)
208 : [lhs] "v"(lhs), [rhs] "v"(rhs));
209 return result;
210}
211
212CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
213{
214 float result;
215 asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
216 : [result] "=v"(result)
217 : [lhs] "v"(lhs), [rhs] "v"(rhs));
218 return result;
219}
220
222{
223 fp16x2_t result;
224 asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
225 : [result] "=v"(result)
226 : [a] "v"(a), [b] "v"(b));
227 return result;
228}
229
231{
232 bf16x2_t result;
233 asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
234 : [result] "=v"(result)
235 : [a] "v"(a), [b] "v"(b));
236 return result;
237}
238
240{
241 fp32x2_t result;
242 asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
243 : [result] "=v"(result)
244 : [lhs] "v"(lhs), [rhs] "v"(rhs));
245 return result;
246}
247} // namespace detail
248
249template <typename Problem_, typename Policy_ = BlockFmhaV3PipelineDefaultPolicy>
251{
264
265 static_assert(std::is_same_v<SaccDataType, SMPLComputeDataType>,
266 "we will the same dist tensor 'sp_compute' for both gemm0 & softmax");
267
269
270 static constexpr ck_tile::index_t kBlockSize = Problem::kBlockSize;
271
272 static constexpr ck_tile::index_t kM0 = BlockFmhaShape::kM0;
273 static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
274 static constexpr ck_tile::index_t kK0 = BlockFmhaShape::kK0;
275 static constexpr ck_tile::index_t kN1 = BlockFmhaShape::kN1;
276 static constexpr ck_tile::index_t kK1 = BlockFmhaShape::kK1;
277 static constexpr ck_tile::index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
278 static constexpr ck_tile::index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
279
280 static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
281
282 static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
283 static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
284 static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
285 static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
286 static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
287 static constexpr bool kStoreLSE = Problem::kStoreLSE;
288
289 // last dimension vector length used to create tensor view(and decide buffer_load vector length)
290 // ... together with tensor distribution. tensor dist should able to overwrite this
291 static constexpr ck_tile::index_t kAlignmentQ =
292 kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
293 static constexpr ck_tile::index_t kAlignmentK =
294 kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
295 static constexpr ck_tile::index_t kAlignmentV =
296 kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
297
298 static constexpr ck_tile::index_t kAlignmentO =
299 kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
300
301 static constexpr ck_tile::index_t kBlockPerCu = []() {
302 if constexpr(Problem::kBlockPerCu != -1)
303 return Problem::kBlockPerCu;
304 else
305 {
306 return 2;
307 }
308 }();
309
311 {
312 // create another LDS buffer for p
313 return ck_tile::max(kM0 * kN1 * sizeof(PDataType),
314 Policy::template GetSmemSize<Problem>() +
315 kM0 * kN0 * sizeof(PDataType));
316 }
317
318 // for debug only
319 template <ck_tile::index_t MPerBlock, ck_tile::index_t NPerBlock>
320 CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc()
321 {
322 using namespace ck_tile;
323 constexpr auto lds_block_desc =
326 number<1>{},
327 number<1>{});
328
329 return lds_block_desc;
330 }
331
332 // for debug only
333 template <ck_tile::index_t MPerBlock>
334 CK_TILE_DEVICE static constexpr auto MakeSimpleLdsDesc1D()
335 {
336 using namespace ck_tile;
337 constexpr auto lds_block_desc = make_naive_tensor_descriptor(
339
340 return lds_block_desc;
341 }
342
343 template <typename DataType, typename Descriptor>
344 CK_TILE_DEVICE static constexpr auto make_lds_tile_window(void* base, const Descriptor& desc)
345 {
346 using namespace ck_tile;
347
348 auto tensor_view =
349 make_tensor_view<address_space_enum::lds>(reinterpret_cast<DataType*>(base), desc);
350 return make_tile_window(tensor_view, desc.get_lengths(), {0, 0});
351 }
352
353 // vmcnt=0~63, lgkmcnt=0~15, expcnt=0~7
354 template <uint16_t Vmcnt, uint8_t Lgkmcnt, uint8_t Expcnt = 7>
355 CK_TILE_DEVICE static constexpr void s_waitcnt()
356 {
357 // vmcnt use bits {[15:14],[3:0]}
358 // expcnt use bits [6:4]
359 // lgkmcnt use bits [11:8]
360 __builtin_amdgcn_s_waitcnt((((0b110000 & Vmcnt) << (14 - 4)) | (0b1111 & Vmcnt)) |
361 ((0b111 & Expcnt) << 4) | ((0b1111 & Lgkmcnt) << 8));
362 }
363
364 template <uint16_t Vmcnt>
365 CK_TILE_DEVICE static constexpr void s_waitcnt_vmcnt()
366 {
368 }
369
370 template <uint8_t Lgkmcnt>
371 CK_TILE_DEVICE static constexpr void s_waitcnt_lgkmcnt()
372 {
374 }
375
376 template <typename QDramBlockWindowTmp,
377 typename KDramBlockWindowTmp,
378 typename VDramBlockWindowTmp,
379 typename LSEDramBlockWindowTmp,
380 typename QElementFunction,
381 typename KElementFunction,
382 typename VElementFunction,
383 typename LSEElementFunction,
384 typename SAccElementFunction,
385 typename PComputeElementFunction,
386 typename OAccElementFunction>
387 CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
388 const QElementFunction& q_element_func,
389 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
390 [[maybe_unused]] const KElementFunction& k_element_func,
391 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
392 [[maybe_unused]] const VElementFunction& v_element_func,
393 LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
394 const LSEElementFunction& lse_element_func,
395 [[maybe_unused]] const SAccElementFunction& s_acc_element_func,
396 const PComputeElementFunction& p_compute_element_func,
397 const OAccElementFunction& o_acc_element_func,
398 FmhaMask mask,
399 float scale_s,
400 void* smem_ptr) const
401 {
402 using namespace ck_tile;
403
404 static_assert(
405 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
406 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
407 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
408 "wrong!");
409
410 static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
411 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
412 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
413 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
414 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
415 "wrong!");
416
417 static_assert(sizeof(SaccDataType) * kM0 * kN0 <= GetSmemSize());
419 reinterpret_cast<SaccDataType*>(static_cast<char*>(smem_ptr)),
421 [[maybe_unused]] auto s_lds_window =
423
425 reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr) +
426 Policy::template GetSmemSize<Problem>()),
428 [[maybe_unused]] auto p_lds_window =
430
432 reinterpret_cast<PDataType*>(static_cast<char*>(smem_ptr)),
434 [[maybe_unused]] auto o_lds_window =
436
438 reinterpret_cast<SMPLComputeDataType*>(static_cast<char*>(smem_ptr) +
439 Policy::template GetSmemSize<Problem>()),
441 [[maybe_unused]] auto m_lds_window =
443
444 const index_t warp_group_id = get_warp_id() / 4;
445
446 // Block GEMM
447 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
448 constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
449
450 auto q_dram_window = make_tile_window_linear(
451 q_dram_block_window_tmp, Policy::template MakeQRegTileDistribution<Problem>());
452
453 // reduction function for softmax
454 const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
455 const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
456
457 auto k_lds_window_store = generate_tuple(
458 [&](auto i_buf) {
460 smem_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf));
461 },
462 number<2>{});
463
464 auto v_lds_window_store = generate_tuple(
465 [&](auto i_buf) {
467 smem_ptr, Policy::template MakeVLdsStoreBlockDescriptor<Problem>(i_buf));
468 },
469 number<2>{});
470
473 nullptr,
474 Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
475 Policy::template MakeKRegTileDistribution<Problem>())),
476 2>
477 k_lds_window_load;
478
481 nullptr,
482 Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
483 Policy::template MakeVRegTileDistribution<Problem>())),
484 2>
485 v_lds_window_load;
486
488 Policy::template MakeQRegTileDistribution<Problem>())) q_tile;
489
490 union kv_tile_type
491 {
492 CK_TILE_DEVICE kv_tile_type() {}
493
494 decltype(load_tile(k_lds_window_load(number<0>{}))) k_tile;
495
496 decltype(load_tile_transpose(v_lds_window_load(number<0>{}))) v_tile;
497 } kv_tile;
498
499 union sp_compute_type
500 {
501 CK_TILE_DEVICE sp_compute_type() {}
502
503 decltype(gemm_0.MakeCBlockTile()) sp_compute;
505 Policy::template MakePRegTileDistribution<Problem>())) p;
506 };
508
509 decltype(gemm_1.MakeCBlockTile()) o_acc;
510 constexpr index_t fmha_alu_D_reg_cnt = 6; // threshold to decide how many fmha_alu_D_upd()
511 // instructions should we move to fmha_alu1()
512 static_assert(fmha_alu_D_reg_cnt <= o_acc.thread_buf_.size());
513
515 sp(number<0>{}).sp_compute, sequence<1>{}, f_max, SMPLComputeDataType{0})) m;
516 decltype(m) l;
517
518 // initialize k_lds_window and v_lds_window
519 static_for<0, 2, 1>{}([&](auto idx) {
520 k_lds_window_load(idx) = make_tile_window(
522 static_cast<char*>(smem_ptr) + (idx)*Policy::template GetSmemSizeKV<Problem>(),
523 Policy::template MakeKLdsLoadBlockDescriptor<Problem>()),
524 Policy::template MakeKRegTileDistribution<Problem>());
525 });
526
527 static_for<0, 2, 1>{}([&](auto idx) {
528 v_lds_window_load(idx) =
530 static_cast<char*>(smem_ptr) +
531 (idx + 2) * Policy::template GetSmemSizeKV<Problem>(),
532 Policy::template MakeVLdsLoadBlockDescriptor<Problem>()),
533 Policy::template MakeVRegTileDistribution<Problem>());
534 });
535
536 {
537 auto origin_q = load_tile(q_dram_window);
538 auto transformed_q = tile_elementwise_in(q_element_func, origin_q);
539
540 q_tile = transformed_q;
541 }
542
543 clear_tile(o_acc);
544 set_tile(m, bit_cast<float>(0xff7fffff)); // a bit larger than -infinity
545 clear_tile(l);
546
547 const auto q_origin = q_dram_window.get_window_origin();
548 const auto [seqlen_k_start, seqlen_k_end] =
549 mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
550
551 const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
552 index_t kv_token_start = seqlen_k_start;
553
554 // check early exit if no work to do
555 if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
556 {
557 if(num_total_loop <= 0)
558 {
559 if constexpr(kStoreLSE)
560 {
561 auto lse =
562 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
563
565
566 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
567 }
568
569 // Note: here occ are all cleard, return it
570 // Note: q loaded but no fence, ignore it.
571 return o_acc;
572 }
573 }
574
575 auto k_dram_window =
576 make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
577 k_dram_block_window_tmp.get_window_lengths(),
578 {seqlen_k_start, 0},
579 Policy::template MakeKDramTileDistribution<Problem>());
580 k_dram_window.init_raw();
581
582 auto v_dram_window =
583 make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
584 v_dram_block_window_tmp.get_window_lengths(),
585 {seqlen_k_start, 0}, // TODO: hdim split?
586 Policy::template MakeVDramTileDistribution<Problem>());
587 v_dram_window.init_raw();
588
589 // prefetch K tile
590 index_t i_total_loops = 0;
591 constexpr index_t k0_loops = kQKHeaddim / kK0;
592 constexpr index_t k1_loops = kN0 / kK1;
593 static_assert(1 == k0_loops);
594 static_assert(1 == k1_loops);
595 static_assert(kN0 == kK1);
596
597 constexpr index_t NumWarpGroups = Problem::kBlockSize / Policy::NumThreadPerWarpGroup;
598 static_assert(NumWarpGroups == 2);
599
600 [[maybe_unused]] auto print_dist_tensor = [&](const auto& dist_tensor, const char* name) {
601 printf("[POYENC] %s (size=%d): %5.2f",
602 name,
603 decltype(dist_tensor.thread_buf_)::size(),
604 ck_tile::type_convert<float>(dist_tensor.thread_buf_[0]));
605 static_for<1, decltype(dist_tensor.thread_buf_)::size(), 1>{}([&](auto i) {
606 printf(", %5.2f", ck_tile::type_convert<float>(dist_tensor.thread_buf_[i]));
607 });
608 printf("\n");
609 };
610
611 [[maybe_unused]] auto print_lds = [&](auto lds_tile_window, const char* name) {
612 const auto num_rows = lds_tile_window.get_window_lengths().at(number<0>{});
613 const auto num_cols = lds_tile_window.get_window_lengths().at(number<1>{});
614
615 auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
616 auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
617
618 if constexpr(true || num_rows < num_cols)
619 {
620 for(int row = 0; row < num_rows; ++row)
621 {
622 int offset = desc.calculate_offset(make_tuple(row, 0));
623 printf("[DEVICE] %s[%3d] = %5.2f",
624 name,
625 row,
627 for(int col = 1; col < num_cols; ++col)
628 {
629 printf(", ");
630 offset = desc.calculate_offset(make_tuple(row, col));
631 printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
632 }
633 printf("\n");
634 }
635 }
636 else
637 {
638 for(int col = 0; col < num_cols; ++col)
639 {
640 int offset = desc.calculate_offset(make_tuple(0, col));
641 printf("[DEVICE] %s[%3d] = %5.2f",
642 name,
643 col,
645 for(int row = 1; row < num_rows; ++row)
646 {
647 printf(", ");
648 offset = desc.calculate_offset(make_tuple(row, col));
649 printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
650 }
651 printf("\n");
652 }
653 }
654 };
655
656 [[maybe_unused]] auto print_lds_1d = [&](auto lds_tile_window, const char* name) {
657 const auto num_elems = lds_tile_window.get_window_lengths().at(number<0>{});
658
659 auto desc = lds_tile_window.get_bottom_tensor_view().desc_;
660 auto data = lds_tile_window.get_bottom_tensor_view().buf_.p_data_;
661
662 int offset = desc.calculate_offset(make_tuple(0));
663 printf("[DEVICE] %s = %5.2f", name, ck_tile::type_convert<float>(data[offset]));
664 for(int e = 1; e < num_elems; ++e)
665 {
666 printf(", ");
667 offset = desc.calculate_offset(make_tuple(e));
668 printf("%5.2f", ck_tile::type_convert<float>(data[offset]));
669 }
670 printf("\n");
671 };
672
673 // K_mem_su_ld_insts = 1 for 32 x 128
674 // V_mem_su_ld_insts = 1 for 128 x 32
675 constexpr int K_mem_su_ld_insts = k_dram_window.get_num_of_access();
676 constexpr int V_mem_su_ld_insts = v_dram_window.get_num_of_access();
677
678 auto K_mem_load = [&](auto k_lds_write_idx) {
679 async_load_tile_raw(k_lds_window_store(k_lds_write_idx), k_dram_window);
680
682 // move K tile windows
683 move_tile_window(k_dram_window, {kN0, 0});
684 };
685
686 auto K_lds_load = [&](auto k_lds_read_idx) {
687 kv_tile.k_tile = load_tile(k_lds_window_load(k_lds_read_idx));
688 };
689
690 auto V_mem_load = [&](auto v_lds_write_idx) {
691 async_load_tile_raw(v_lds_window_store(v_lds_write_idx), v_dram_window);
692
694 move_tile_window(v_dram_window, {kK1, 0});
695 };
696
697 auto V_lds_load = [&](auto v_lds_read_idx) {
698 kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx));
699 };
700
701 decltype(m) m_old;
702 SMPLComputeDataType o_acc_scale; // rescale o_acc in fmha_alu1() & fmha_alu_D_upd()
704 statically_indexed_array<decltype(sp(number<0>{}).sp_compute), 2> sp_delta;
705
706 auto fmha_alu0 = [&](auto sp_reg_idx) {
707 m_old = m; // m{j-1}
708 static_assert(m.thread_buf_.size() == 1,
709 "assuming that each thread holds 1 rowmax value");
711 sp(sp_reg_idx).sp_compute, sequence<1>{}, f_max, m.thread_buf_[0]);
712#if defined(__gfx950__)
713 // assuming that we are using 32x32 mfma
714 int32x2_t swapped_regs =
715 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(m_latest.thread_buf_[0]),
716 bit_cast<int32_t>(m_latest.thread_buf_[0]),
717 false,
718 false);
720 m_latest.thread_buf_[0] = f_max(bit_cast<SMPLComputeDataType>(swapped_regs.x),
721 bit_cast<SMPLComputeDataType>(swapped_regs.y));
722#else
724#endif
725 m = m_latest;
726
727 constexpr auto p_spans =
728 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
729 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
730 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
731 constexpr auto i_j_idx = make_tuple(idx0, idx1);
732 sp_delta(sp_reg_idx)(i_j_idx) = detail::fma_impl_vsv(
733 sp(sp_reg_idx).sp_compute(i_j_idx), scale_s, -scale_s * m(i_j_idx));
734 });
735 });
737 };
738
739 auto fmha_alu1 = [&](auto sp_reg_idx) {
740 constexpr auto p_spans =
741 std::decay_t<decltype(sp(sp_reg_idx).sp_compute)>::get_distributed_spans();
742 sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
743 sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
744 constexpr auto i_j_idx = make_tuple(idx0, idx1);
745 sp(sp_reg_idx).sp_compute(i_j_idx) =
746 ck_tile::exp2(sp_delta(sp_reg_idx)(i_j_idx));
747 });
748 });
749
751 sp(sp_reg_idx).sp_compute,
752 sequence<1>{},
753 f_sum,
754 SMPLComputeDataType{0}); // rowsum(Pcompute{j})
755 static_assert(rowsum_p.thread_buf_.size() == 1,
756 "assuming that each thread holds 1 rowsum value");
757#if defined(__gfx950__)
758 // assuming that we are using 32x32 mfma
759 int32x2_t swapped_regs =
760 __builtin_amdgcn_permlane32_swap(bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
761 bit_cast<int32_t>(rowsum_p.thread_buf_[0]),
762 false,
763 false);
764 rowsum_p.thread_buf_[0] = f_sum(bit_cast<SMPLComputeDataType>(swapped_regs.x),
765 bit_cast<SMPLComputeDataType>(swapped_regs.y));
766#else
768#endif
769
770 // l{j}
775 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
776 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
777 constexpr auto i_idx = make_tuple(idx0);
778 const auto tmp = ck_tile::exp2(scale_s * (m_old[i_idx] - m[i_idx]));
779
780 l(i_idx) = detail::add_impl_vv(tmp * l[i_idx], rowsum_p[i_idx]);
781 });
782
783 // update partial o_acc [0, fmha_alu_D_reg_cnt)
785 o_acc.thread_buf_[idx] = detail::mul_impl_vv(o_acc.thread_buf_[idx], o_acc_scale);
786 });
787
792 static_assert(sp(sp_reg_idx).p.thread_buf_.size() % 2 == 0);
793 static_for<0, sp(sp_reg_idx).p.thread_buf_.size(), 2>{}([&](auto idx) {
794 float x = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx]);
795 float y = p_compute_element_func(sp(sp_reg_idx).sp_compute.thread_buf_[idx + 1]);
796 if constexpr(std::is_same_v<PDataType, fp16_t>)
797 {
798 auto casted = detail::cvt_pk_fp16_f32(x, y);
799 sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
800 sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
801 }
802 else
803 {
804 auto casted = detail::cvt_pk_bf16_f32(x, y);
805 sp(sp_reg_idx).p.thread_buf_[idx] = casted.x;
806 sp(sp_reg_idx).p.thread_buf_[idx + 1] = casted.y;
807 }
808 });
809
813 };
814
815 auto gemm = [&](auto sp_reg_idx, auto gemm_idx) {
816 if constexpr(gemm_idx == 0)
817 {
818 clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
819 gemm_0(sp(sp_reg_idx).sp_compute,
820 get_slice_tile(q_tile,
821 sequence<0, (k0_loops - 1) * kK0>{},
823 get_slice_tile(kv_tile.k_tile,
824 sequence<0, (k0_loops - 1) * kK0>{},
826 }
827 else
828 {
829 gemm_1(o_acc,
830 get_slice_tile(sp(sp_reg_idx).p,
831 sequence<0, (k1_loops - 1) * kK1>{},
833 get_slice_tile(kv_tile.v_tile,
834 sequence<0, (k1_loops - 1) * kK1>{},
836 }
837 };
838
839 auto cl_calc = [&](auto sp_reg_idx, auto gemm_idx) {
840 if constexpr(gemm_idx == 0)
841 {
842 clear_tile(sp(sp_reg_idx).sp_compute); // initialize C
843 gemm_0(sp(sp_reg_idx).sp_compute,
844 get_slice_tile(q_tile,
845 sequence<0, (k0_loops - 1) * kK0>{},
847 get_slice_tile(kv_tile.k_tile,
848 sequence<0, (k0_loops - 1) * kK0>{},
850 }
851 else
852 {
853 gemm_1(o_acc,
854 get_slice_tile(sp(sp_reg_idx).p,
855 sequence<0, (k1_loops - 1) * kK1>{},
857 get_slice_tile(kv_tile.v_tile,
858 sequence<0, (k1_loops - 1) * kK1>{},
860 fmha_alu0(number<1>{} - sp_reg_idx);
861 }
862 };
863
864 auto fmha_alu_D_upd = [&] {
865 o_acc_scale = ck_tile::exp2(scale_s * (m_old.thread_buf_[0] - m.thread_buf_[0]));
866
867 fp32x2_t pk_o_acc_scale;
868 pk_o_acc_scale.x = o_acc_scale;
869 pk_o_acc_scale.y = o_acc_scale;
870
871 static_assert((o_acc.thread_buf_.size() - fmha_alu_D_reg_cnt) % 2 == 0);
872#if CK_TILE_DISABLE_PACKED_FP32
873 static_assert(fmha_alu_D_reg_cnt + 2 <= o_acc.thread_buf_.size());
875 [&](auto idx) { o_acc.thread_buf_[idx] *= o_acc_scale; });
876#endif
877
878 constexpr auto issued_D_reg_cnt =
879#if CK_TILE_DISABLE_PACKED_FP32
880 fmha_alu_D_reg_cnt + 2
881#else
882 fmha_alu_D_reg_cnt
883#endif
884 ;
887 // update partial o_acc after [issued_D_reg_cnt]
888 static_for<issued_D_reg_cnt, o_acc.thread_buf_.size(), 2>{}([&](auto idx) {
889 fp32x2_t input;
890 input.x = o_acc.thread_buf_[idx];
891 input.y = o_acc.thread_buf_[idx + 1];
892
893 auto output = detail::pk_mul_f32(input, pk_o_acc_scale);
894
895 o_acc.thread_buf_[idx] = output.x;
896 o_acc.thread_buf_[idx + 1] = output.y;
897 });
898 };
899
900 auto fmha_mask = [&](auto sp_reg_idx) {
901 if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
902 {
903 bool need_perpixel_check = mask.IsEdgeTile(
904 q_origin.at(number<0>{}), kv_token_start, number<kM0>{}, number<kN0>{});
905 if(need_perpixel_check)
906 {
907 set_tile_if(sp(sp_reg_idx).sp_compute,
909 [&](auto tile_idx) {
910 const auto row =
911 q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
912 const auto col = kv_token_start + tile_idx.at(number<1>{});
913 return mask.IsOutOfBound(row, col);
914 });
915 }
916 }
917 };
918
919 auto cl_load = [&](auto load_type, auto mem_wr_idx, auto lds_rd_idx) {
920 if constexpr(load_type == 0)
921 {
922 V_mem_load(mem_wr_idx);
923 K_lds_load(lds_rd_idx);
924 }
925 else
926 {
927 K_mem_load(mem_wr_idx);
928 V_lds_load(lds_rd_idx);
929 }
930 };
931
932 auto core_loop = [&](auto cl_p) {
933 auto gemm0 = number<0>{};
934 auto gemm1 = number<1>{};
935
936 auto memV = number<0>{};
937 auto memK = number<1>{};
938
940
941 auto iteration = [&](auto pi) {
942 auto xdl_SP_p01_reg_idx = number<1>{} - pi;
943 auto xdl_SP_p23_reg_idx = pi;
944
945 auto K_w0_lds_wr_idx = number<1>{} - pi;
946 auto V_w0_lds_wr_idx = pi;
947 auto K_w0_lds_rd_idx = pi;
948 auto V_w0_lds_rd_idx = pi;
949
950 auto K_w4_lds_wr_idx = number<1>{} - pi;
951 auto V_w4_lds_wr_idx = number<1>{} - pi;
952 auto K_w4_lds_rd_idx = number<1>{} - pi;
953 auto V_w4_lds_rd_idx = pi;
954
955 bool result = true;
956
957 if constexpr(cl_p == 0)
958 {
959#if ADD_SBARRIER_FOR_PHASE0
960 __builtin_amdgcn_sched_barrier(0);
961 __builtin_amdgcn_s_barrier();
962#endif
963 __builtin_amdgcn_sched_barrier(0);
964 // phase0
965 if constexpr(pi == 0)
966 {
967 ASM_MARKER("phase0 Wave0-3 (pi=0)");
968 }
969 else
970 {
971 ASM_MARKER("phase0 Wave0-3 (pi=1)");
972 }
974 __builtin_amdgcn_sched_barrier(0);
975 cl_calc(xdl_SP_p01_reg_idx, gemm0);
976 fmha_alu1(xdl_SP_p23_reg_idx);
977
978 Scheduler::schedule(cl_p, number<0>{});
979 __builtin_amdgcn_sched_barrier(0);
980 // phase1
981 ASM_MARKER("phase1 Wave0-3");
983 __builtin_amdgcn_sched_barrier(0);
984 __builtin_amdgcn_s_barrier();
985 __builtin_amdgcn_sched_barrier(0);
986 cl_load(memK, K_w0_lds_wr_idx, V_w0_lds_rd_idx);
987 Scheduler::schedule(cl_p, number<1>{});
988 fmha_mask(xdl_SP_p01_reg_idx);
989
990 __builtin_amdgcn_sched_barrier(0);
991 // phase2
992 ASM_MARKER("phase2 Wave0-3");
994 __builtin_amdgcn_sched_barrier(0);
995 __builtin_amdgcn_s_barrier();
996 __builtin_amdgcn_sched_barrier(0);
997 asm volatile("s_nop 0");
998 __builtin_amdgcn_sched_barrier(0);
999 cl_calc(xdl_SP_p23_reg_idx, gemm1);
1000
1001 Scheduler::schedule(cl_p, number<2>{});
1002 __builtin_amdgcn_sched_barrier(0);
1003 fmha_alu_D_upd();
1004
1005 __builtin_amdgcn_sched_barrier(0);
1006 // phase3
1007 ASM_MARKER("phase3 Wave0-3");
1009 __builtin_amdgcn_sched_barrier(0);
1010 __builtin_amdgcn_s_barrier();
1011 __builtin_amdgcn_sched_barrier(0);
1012 cl_load(memV, V_w0_lds_wr_idx, K_w0_lds_rd_idx);
1013
1014 Scheduler::schedule(cl_p, number<3>{});
1015 kv_token_start += kN0;
1016 if(num_total_loop <= ++i_total_loops)
1017 {
1018 result = false;
1019 }
1020 }
1021 else
1022 {
1023#if ADD_SBARRIER_FOR_PHASE0
1024 __builtin_amdgcn_sched_barrier(0);
1025 __builtin_amdgcn_s_barrier();
1026#endif
1027 __builtin_amdgcn_sched_barrier(0);
1028 // phase0
1029 if constexpr(pi == 0)
1030 {
1031 ASM_MARKER("phase0 Wave4-7 (pi=0)");
1032 }
1033 else
1034 {
1035 ASM_MARKER("phase0 Wave4-7 (pi=1)");
1036 }
1037 cl_load(memV, V_w4_lds_wr_idx, K_w4_lds_rd_idx);
1038
1039 Scheduler::schedule(cl_p, number<0>{});
1040 __builtin_amdgcn_sched_barrier(0);
1041 // phase1
1042 ASM_MARKER("phase1 Wave4-7");
1044 __builtin_amdgcn_sched_barrier(0);
1045 __builtin_amdgcn_s_barrier();
1046 __builtin_amdgcn_sched_barrier(0);
1047 asm volatile("s_nop 1");
1048 __builtin_amdgcn_sched_barrier(0);
1049 cl_calc(xdl_SP_p01_reg_idx, gemm0);
1050 fmha_alu1(xdl_SP_p23_reg_idx);
1051
1052 Scheduler::schedule(cl_p, number<1>{});
1053 __builtin_amdgcn_sched_barrier(0);
1054 // phase2
1055 ASM_MARKER("phase2 Wave4-7");
1056 __builtin_amdgcn_s_barrier();
1057 __builtin_amdgcn_sched_barrier(0);
1058 cl_load(memK, K_w4_lds_wr_idx, V_w4_lds_rd_idx);
1059 Scheduler::schedule(cl_p, number<2>{});
1060 fmha_mask(xdl_SP_p01_reg_idx);
1061
1062 kv_token_start += kN0;
1063 if(num_total_loop <= ++i_total_loops)
1064 {
1065 result = false;
1066 }
1067
1068 __builtin_amdgcn_sched_barrier(0);
1069 // phase3
1070 ASM_MARKER("phase3 Wave4-7");
1072 __builtin_amdgcn_sched_barrier(0);
1073 __builtin_amdgcn_s_barrier();
1074 __builtin_amdgcn_sched_barrier(0);
1075 asm volatile("s_nop 1");
1076 __builtin_amdgcn_sched_barrier(0);
1077 cl_calc(xdl_SP_p23_reg_idx, gemm1);
1078
1079 Scheduler::schedule(cl_p, number<3>{});
1080 __builtin_amdgcn_sched_barrier(0);
1081 fmha_alu_D_upd();
1082 }
1083 return result;
1084 };
1085 return iteration(number<0>{}) && iteration(number<1>{});
1086 };
1087
1088 auto fmha_post_process = [&](auto d) {
1089 auto ps_pi = number<1>{} - d;
1090 auto V_lds_rd_idx = ps_pi;
1091
1092 if(1 < num_total_loop)
1093 {
1095 }
1096 else
1097 {
1099 }
1100 __builtin_amdgcn_s_barrier();
1101
1102 V_lds_load(V_lds_rd_idx);
1103 fmha_alu1(ps_pi);
1104
1106
1107 auto xdl_SP_p23_reg_idx = ps_pi;
1108 gemm(xdl_SP_p23_reg_idx, /*gemm_idx=*/number<1>{});
1109 };
1110
1111 // pre-stage
1112 {
1113 ASM_MARKER("before pre-stage");
1114 // (1) load K0 to LDS & VGPR
1115 K_mem_load(number<0>{}); // mem_K0
1116
1118 __builtin_amdgcn_s_barrier();
1119
1120 K_lds_load(number<0>{}); // lds_K0
1121
1123 __builtin_amdgcn_s_barrier();
1124
1125 // (2) prefetch K1 and V0 to LDS in parallel with GEMM0
1126 if(1 < num_total_loop)
1127 {
1128 K_mem_load(number<1>{}); // mem_K1
1129 }
1130 V_mem_load(number<0>{}); // mem_V0
1131
1132 // (3) mfma (Q*K0) + softmax
1133 gemm(number<0>{}, /*gemm_idx=*/number<0>{});
1134
1135 fmha_mask(number<0>{});
1137 fmha_alu0(number<0>{});
1138 fmha_alu_D_upd();
1139
1140 kv_token_start += kN0;
1141 ++i_total_loops;
1142 if(num_total_loop <= i_total_loops)
1143 {
1144 goto label_main_loops_exit;
1145 }
1146
1147 if(2 < num_total_loop)
1148 {
1149 K_mem_load(number<0>{}); // mem_K2
1150
1152 __builtin_amdgcn_s_barrier();
1153 }
1154
1155 ASM_MARKER("end pre-stage");
1156 }
1157
1158 if(1 < num_total_loop)
1159 {
1160 if(warp_group_id == 0)
1161 {
1162 V_mem_load(number<1>{}); // V1
1163 K_lds_load(number<1>{}); // K1
1164
1165 __builtin_amdgcn_s_setprio(0);
1166 __builtin_amdgcn_s_barrier();
1167 while(core_loop(number<0>{}))
1168 ;
1169 }
1170 if(warp_group_id != 0)
1171 {
1172 __builtin_amdgcn_s_setprio(1);
1173 __builtin_amdgcn_s_barrier();
1174 while(core_loop(number<1>{}))
1175 ;
1176 }
1177 }
1178 label_main_loops_exit:
1179 if(num_total_loop % 2)
1180 {
1181 fmha_post_process(number<1>{});
1182 }
1183 if(!(num_total_loop % 2))
1184 {
1185 fmha_post_process(number<0>{});
1186 }
1187
1188 // store lse
1189 if constexpr(kStoreLSE)
1190 {
1191 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
1192
1193 constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
1194 sweep_tile_span(lse_spans[number<0>{}], [&](auto idx0) {
1195 constexpr auto i_idx = make_tuple(idx0);
1196 lse(i_idx) = m[i_idx] / C_LOG2E + log(l[i_idx]);
1197 });
1198
1199 store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
1200 }
1201
1202 // finally, O
1203 constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
1204
1205 sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
1206 constexpr auto i_idx = make_tuple(idx0);
1207 const auto tmp = [&]() {
1208 if constexpr(FmhaMask::IsMasking)
1209 {
1210 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1211 }
1212 else
1213 return 1 / l[i_idx];
1214 }();
1215 sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
1216 constexpr auto i_j_idx = make_tuple(idx0, idx1);
1217 o_acc(i_j_idx) *= tmp;
1218 });
1219 });
1220
1221 o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
1222
1223 return o_acc;
1224 }
1225
1226 template <typename QDramBlockWindowTmp,
1227 typename KDramBlockWindowTmp,
1228 typename VDramBlockWindowTmp,
1229 typename LSEDramBlockWindowTmp>
1230 CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
1231 const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
1232 const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
1233 LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
1234 FmhaMask mask,
1235 float scale_s,
1236 void* smem_ptr) const
1237 {
1238 using namespace ck_tile;
1239
1240 return operator()(q_dram_block_window_tmp,
1241 identity{},
1242 k_dram_block_window_tmp,
1243 identity{},
1244 v_dram_block_window_tmp,
1245 identity{},
1246 lse_dram_block_window_tmp,
1247 identity{},
1248 identity{},
1249 identity{},
1250 identity{},
1251 mask,
1252 scale_s,
1253 smem_ptr);
1254 }
1255};
1256
1257} // namespace ck_tile
#define ASM_MARKER(marker)
Definition block_fmha_fwd_v3_pipeline.hpp:17
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
Definition block_fmha_fwd_v3_pipeline.hpp:230
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
Definition block_fmha_fwd_v3_pipeline.hpp:190
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
Definition block_fmha_fwd_v3_pipeline.hpp:212
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
Definition block_fmha_fwd_v3_pipeline.hpp:203
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
Definition block_fmha_fwd_v3_pipeline.hpp:221
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
Definition block_fmha_fwd_v3_pipeline.hpp:239
Definition tile/core/algorithm/cluster_descriptor.hpp:13
_Float16 fp16x2_t
Definition half.hpp:385
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
bfloat16_t bf16x2_t
Definition pk_fp4.hpp:24
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
Loads a tile from a tensor with a transpose and returns the resulting tensor with a new (transposed) tile distribution.
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto generate_tuple(F &&f, number< N >)
Definition tile/core/container/tuple.hpp:429
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE constexpr auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition tile_window_linear.hpp:993
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
float fp32x2_t
Definition pk_fp4.hpp:22
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition load_tile.hpp:133
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
int32_t int32x2_t
Definition vector_type.hpp:154
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition block_fmha_fwd_v3_pipeline.hpp:251
static constexpr bool kPadSeqLenQ
Definition block_fmha_fwd_v3_pipeline.hpp:283
ck_tile::remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_fwd_v3_pipeline.hpp:260
static constexpr bool kPadHeadDimV
Definition block_fmha_fwd_v3_pipeline.hpp:286
ck_tile::remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_fwd_v3_pipeline.hpp:258
static constexpr ck_tile::index_t kQKHeaddim
Definition block_fmha_fwd_v3_pipeline.hpp:277
ck_tile::remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_fwd_v3_pipeline.hpp:255
ck_tile::remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_fwd_v3_pipeline.hpp:262
ck_tile::remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_fwd_v3_pipeline.hpp:268
ck_tile::remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_fwd_v3_pipeline.hpp:261
static constexpr ck_tile::index_t kAlignmentV
Definition block_fmha_fwd_v3_pipeline.hpp:295
static constexpr ck_tile::index_t kN0
Definition block_fmha_fwd_v3_pipeline.hpp:273
static constexpr ck_tile::index_t kK0
Definition block_fmha_fwd_v3_pipeline.hpp:274
ck_tile::remove_cvref_t< Problem_ > Problem
Definition block_fmha_fwd_v3_pipeline.hpp:252
static CK_TILE_DEVICE constexpr void s_waitcnt()
Definition block_fmha_fwd_v3_pipeline.hpp:355
static CK_TILE_DEVICE constexpr void s_waitcnt_lgkmcnt()
Definition block_fmha_fwd_v3_pipeline.hpp:371
static CK_TILE_DEVICE constexpr auto make_lds_tile_window(void *base, const Descriptor &desc)
Definition block_fmha_fwd_v3_pipeline.hpp:344
static constexpr ck_tile::index_t kAlignmentK
Definition block_fmha_fwd_v3_pipeline.hpp:293
static constexpr ck_tile::index_t kAlignmentO
Definition block_fmha_fwd_v3_pipeline.hpp:298
static CK_TILE_DEVICE constexpr auto MakeSimpleLdsDesc()
Definition block_fmha_fwd_v3_pipeline.hpp:320
ck_tile::remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_fwd_v3_pipeline.hpp:254
static constexpr bool kIsGroupMode
Definition block_fmha_fwd_v3_pipeline.hpp:282
static constexpr ck_tile::index_t kSubQKHeaddim
Definition block_fmha_fwd_v3_pipeline.hpp:278
ck_tile::remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_fwd_v3_pipeline.hpp:259
static CK_TILE_DEVICE constexpr auto MakeSimpleLdsDesc1D()
Definition block_fmha_fwd_v3_pipeline.hpp:334
ck_tile::remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_fwd_v3_pipeline.hpp:257
ck_tile::remove_cvref_t< Policy_ > Policy
Definition block_fmha_fwd_v3_pipeline.hpp:253
static constexpr bool kPadSeqLenK
Definition block_fmha_fwd_v3_pipeline.hpp:284
ck_tile::remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_fwd_v3_pipeline.hpp:263
static constexpr ck_tile::index_t kN1
Definition block_fmha_fwd_v3_pipeline.hpp:275
static constexpr ck_tile::index_t kK1
Definition block_fmha_fwd_v3_pipeline.hpp:276
static constexpr bool kPadHeadDimQ
Definition block_fmha_fwd_v3_pipeline.hpp:285
static constexpr bool kStoreLSE
Definition block_fmha_fwd_v3_pipeline.hpp:287
static constexpr ck_tile::index_t kM0
Definition block_fmha_fwd_v3_pipeline.hpp:272
static constexpr ck_tile::index_t kBlockPerCu
Definition block_fmha_fwd_v3_pipeline.hpp:301
static constexpr ck_tile::index_t kAlignmentQ
Definition block_fmha_fwd_v3_pipeline.hpp:291
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &k_element_func, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition block_fmha_fwd_v3_pipeline.hpp:387
static CK_TILE_DEVICE constexpr void s_waitcnt_vmcnt()
Definition block_fmha_fwd_v3_pipeline.hpp:365
ck_tile::remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_fwd_v3_pipeline.hpp:256
static constexpr ck_tile::index_t kBlockSize
Definition block_fmha_fwd_v3_pipeline.hpp:270
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_fwd_v3_pipeline.hpp:310
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, float scale_s, void *smem_ptr) const
Definition block_fmha_fwd_v3_pipeline.hpp:1230
static CK_TILE_DEVICE constexpr void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition block_fmha_fwd_v3_pipeline.hpp:119
static CK_TILE_DEVICE constexpr void schedule(ck_tile::number< WaveGroup >, ck_tile::number< Phase >)
Definition block_fmha_fwd_v3_pipeline.hpp:45
Definition block_fmha_fwd_v3_pipeline.hpp:39
Definition tile/core/utility/functional.hpp:86
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tensor_view.hpp:41
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469