block_fmha_bwd_pipeline_default_policy.hpp Source File

block_fmha_bwd_pipeline_default_policy.hpp Source File

Composable Kernel: block_fmha_bwd_pipeline_default_policy.hpp Source File
block_fmha_bwd_pipeline_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
19
20namespace ck_tile {
21
23{
24 template <index_t ndim>
25 static constexpr auto swap_last2 = generate_sequence_v2(
26 [](auto i) {
27 return number < i == ndim - 2 ? ndim - 1 : i == ndim - 1 ? ndim - 2 : i > {};
28 },
29 number<ndim>{});
30
// Assembles the compile-time types for a block-level GEMM whose operands are QDataType and
// KDataType with AccDataType accumulation. The block tile is described by
// sequence<kM0, kN0, kK0>; the warp layout and warp tile are Gemm0BlockWarps / Gemm0WarpTile.
// NOTE(review): the listing jumps from line 61 to 63, so the statement that returns the
// result built from GemmProblem and BlockGemmPolicy is not visible in this view.
31 template <typename Problem>
32 CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
33 {
// Problem description: operand/accumulator types, thread-block size and tile shape.
34 using GemmProblem =
35 BlockGemmProblem<typename Problem::QDataType,
36 typename Problem::KDataType,
37 typename Problem::AccDataType,
38 Problem::kBlockSize,
39 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
40 Problem::BlockFmhaShape::kN0,
41 Problem::BlockFmhaShape::kK0>,
42 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
43 typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
44
// Warp-level GEMM selected from the three Gemm0WarpTile extents. The last argument is
// false only when the first warp-tile extent equals 16, true otherwise.
// NOTE(review): presumably that last slot is the swizzle-access flag (the sibling
// GetPTOGradTBlockGemm labels the same position "SwizzleAccess") -- confirm.
45 using WarpGemm = WarpGemmDispatcher<
46 typename Problem::QDataType,
47 typename Problem::KDataType,
48 typename Problem::AccDataType,
49 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
50 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
51 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
52 false,
53 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
54
// Policy tying together the operand types, the block warp layout and the chosen WarpGemm.
55 using BlockGemmPolicy =
56 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
57 typename Problem::KDataType,
58 typename Problem::AccDataType,
59 typename Problem::BlockFmhaShape::Gemm0BlockWarps,
60 WarpGemm>;
61
63 }
64
// Assembles the compile-time types for a block-level GEMM whose operands are GemmDataType
// and OGradDataType with AccDataType accumulation. The block tile is described by
// sequence<kN0, kVHeaddim, kK1>; the warp layout and warp tile are Gemm1BlockWarps /
// Gemm1WarpTile. Unlike GetQKBlockGemm this is marked CK_TILE_DEVICE only.
// NOTE(review): the listing jumps from line 99 to 101, so the statement that returns the
// result built from GemmProblem and BlockGemmPolicy is not visible in this view.
65 template <typename Problem>
66 CK_TILE_DEVICE static constexpr auto GetPTOGradTBlockGemm()
67 {
// Problem description: operand/accumulator types, thread-block size and tile shape.
68 using GemmProblem =
69 BlockGemmProblem<typename Problem::GemmDataType,
70 typename Problem::OGradDataType,
71 typename Problem::AccDataType,
72 Problem::kBlockSize,
73 TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
74 Problem::BlockFmhaShape::kVHeaddim,
75 Problem::BlockFmhaShape::kK1>,
76 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
77 typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
78
// Warp-level GEMM selected from the three Gemm1WarpTile extents. The access-count
// attribute is Double when the third warp-tile extent is 32 and Single otherwise.
79 using WarpGemm =
80 WarpGemmDispatcher<typename Problem::GemmDataType,
81 typename Problem::OGradDataType,
82 typename Problem::AccDataType,
83 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
84 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
85 Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
86 true,
87 false, // SwizzleAccess
88 false, // UseStructuredSparsity
89 (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
90 ? WGAttrNumAccessEnum ::Double
91 : WGAttrNumAccessEnum ::Single>;
92
// Policy tying together the operand types, the block warp layout and the chosen WarpGemm.
93 using BlockGemmPolicy =
94 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
95 typename Problem::OGradDataType,
96 typename Problem::AccDataType,
97 typename Problem::BlockFmhaShape::Gemm1BlockWarps,
98 WarpGemm>;
99
101 }
102
// Assembles the compile-time types for a block-level GEMM whose operands are OGradDataType
// and VDataType with AccDataType accumulation. The block tile is described by
// sequence<kM0, kN0, kK2>; the warp layout and warp tile are Gemm2BlockWarps / Gemm2WarpTile.
// NOTE(review): the signature (listing line 104) and the return statement (listing line
// 134) are missing from this view, so the function name and result are not shown here.
103 template <typename Problem>
105 {
// Problem description: operand/accumulator types, thread-block size and tile shape.
106 using GemmProblem =
107 BlockGemmProblem<typename Problem::OGradDataType,
108 typename Problem::VDataType,
109 typename Problem::AccDataType,
110 Problem::kBlockSize,
111 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
112 Problem::BlockFmhaShape::kN0,
113 Problem::BlockFmhaShape::kK2>,
114 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
115 typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
116
// Warp-level GEMM selected from the three Gemm2WarpTile extents.
// NOTE(review): the final flag is derived from Gemm0WarpTile's first extent, while every
// other argument in this block uses Gemm2WarpTile. This looks like a copy from the Gemm0
// builder; confirm whether Gemm2WarpTile was intended (harmless only if both tiles always
// share the same first extent).
117 using WarpGemm = WarpGemmDispatcher<
118 typename Problem::OGradDataType,
119 typename Problem::VDataType,
120 typename Problem::AccDataType,
121 Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
122 Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
123 Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
124 false,
125 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
126
// Policy tying together the operand types, the block warp layout and the chosen WarpGemm.
127 using BlockGemmPolicy =
128 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
129 typename Problem::VDataType,
130 typename Problem::AccDataType,
131 typename Problem::BlockFmhaShape::Gemm2BlockWarps,
132 WarpGemm>;
133
135 }
136
// Assembles the compile-time types for a block-level GEMM whose operands are GemmDataType
// and QDataType with AccDataType accumulation. The block tile is described by
// sequence<kN0, kQKHeaddim, kK3>; the warp layout and warp tile are Gemm3BlockWarps /
// Gemm3WarpTile.
// NOTE(review): the signature (listing line 138) and the return statement (listing line
// 172) are missing from this view, so the function name and result are not shown here.
137 template <typename Problem>
139 {
// Problem description: operand/accumulator types, thread-block size and tile shape.
140 using GemmProblem =
141 BlockGemmProblem<typename Problem::GemmDataType,
142 typename Problem::QDataType,
143 typename Problem::AccDataType,
144 Problem::kBlockSize,
145 TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
146 Problem::BlockFmhaShape::kQKHeaddim,
147 Problem::BlockFmhaShape::kK3>,
148 typename Problem::BlockFmhaShape::Gemm3BlockWarps,
149 typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
150
// Warp-level GEMM selected from the three Gemm3WarpTile extents.
// NOTE(review): the Double/Single access-count choice tests Gemm1WarpTile's third extent,
// while every other argument in this block uses Gemm3WarpTile. This looks like a copy from
// the Gemm1 builder; confirm whether Gemm3WarpTile was intended.
151 using WarpGemm =
152 WarpGemmDispatcher<typename Problem::GemmDataType,
153 typename Problem::QDataType,
154 typename Problem::AccDataType,
155 Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
156 Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
157 Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
158 true,
159 false, // SwizzleAccess
160 false, // UseStructuredSparsity
161 (Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32)
162 ? WGAttrNumAccessEnum ::Double
163 : WGAttrNumAccessEnum ::Single>;
164
// Policy tying together the operand types, the block warp layout and the chosen WarpGemm.
165 using BlockGemmPolicy =
166 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
167 typename Problem::QDataType,
168 typename Problem::AccDataType,
169 typename Problem::BlockFmhaShape::Gemm3BlockWarps,
170 WarpGemm>;
171
173 }
174
175 template <typename Problem>
177 {
178 using GemmProblem =
179 BlockGemmProblem<typename Problem::GemmDataType,
180 typename Problem::KDataType,
181 typename Problem::AccDataType,
182 Problem::kBlockSize,
183 TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
184 Problem::BlockFmhaShape::kQKHeaddim,
185 Problem::BlockFmhaShape::kK4>,
186 typename Problem::BlockFmhaShape::Gemm4BlockWarps,
187 typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
188
189 using WarpGemm = WarpGemmDispatcher<typename Problem::GemmDataType,
190 typename Problem::KDataType,
191 typename Problem::AccDataType,
192 Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
193 Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
194 Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
195 false>;
196
197 using BlockGemmPolicy =
198 BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
199 typename Problem::KDataType,
200 typename Problem::AccDataType,
201 typename Problem::BlockFmhaShape::Gemm4BlockWarps,
202 WarpGemm>;
203
205 }
206
207 // these are for global load
// Returns the per-access vector width, in QDataType elements, used for Q.
// The width is the 16-byte maximum when each thread's share of the kM0 x kQKHeaddim tile
// holds at least kMinVecLoad such vectors; otherwise it falls back to
// total_pixels / kMinVecLoad.
// NOTE(review): listing line 211 (presumably the QDataType alias) is missing from this view.
208 template <typename Problem>
209 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
210 {
212 constexpr index_t kBlockSize = Problem::kBlockSize;
213 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
214 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
// Element counts that fit in 16 bytes and in 4 bytes respectively.
215 constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
216 constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
217
// Elements of the tile handled by each thread of the block.
218 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
219
220 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
221 ? kMaxVecLoad
222 : (total_pixels / kMinVecLoad);
223
224 return kVecLoad;
225 }
226
// Returns the per-access vector width, in KDataType elements, used for K.
// The width is the 16-byte maximum when each thread's share of the kN0 x kQKHeaddim tile
// holds at least kMinVecLoad such vectors; otherwise it falls back to
// total_pixels / kMinVecLoad.
// NOTE(review): listing line 230 (presumably the KDataType alias) is missing from this view.
227 template <typename Problem>
228 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
229 {
231 constexpr index_t kBlockSize = Problem::kBlockSize;
232 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
233 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
// Element counts that fit in 16 bytes and in 4 bytes respectively.
234 constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
235 constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
236
// Elements of the tile handled by each thread of the block.
237 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
238
239 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
240 ? kMaxVecLoad
241 : (total_pixels / kMinVecLoad);
242
243 return kVecLoad;
244 }
245
// Returns the per-access vector width, in VDataType elements, used for V.
// Unlike the Q/K variants there is no minimum-width term: the result is simply each
// thread's share of the kN0 x kVHeaddim tile, capped at the 16-byte maximum.
// NOTE(review): listing line 249 (presumably the VDataType alias) is missing from this view.
246 template <typename Problem>
247 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
248 {
250 constexpr index_t kBlockSize = Problem::kBlockSize;
251 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
252 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
// Element count that fits in 16 bytes.
253 constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
// Elements of the tile handled by each thread of the block.
254 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
255
256 return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
257 }
258
// Returns the number of ODataType elements that fit in 16 bytes, independent of tile shape.
// NOTE(review): listing line 262 (presumably the ODataType alias) is missing from this view.
259 template <typename Problem>
260 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentO()
261 {
263 return 16 / sizeof(ODataType);
264 }
265
266 template <typename Problem>
268 {
270 constexpr index_t kBlockSize = Problem::kBlockSize;
271 constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
272 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
273 constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
274 constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
275
276 constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
277
278 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
279 ? kMaxVecLoad
280 : (total_pixels / kMinVecLoad);
281
282 return kVecLoad;
283 }
284
// Returns the per-access vector width, in BiasDataType elements, used for the bias tile.
// The width is the 16-byte maximum when each thread's share of the kM0 x kN0 tile holds at
// least kMinVecLoad such vectors; otherwise it falls back to total_pixels / kMinVecLoad.
// NOTE(review): listing line 288 (presumably the BiasDataType alias) is missing from this
// view.
285 template <typename Problem>
286 CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
287 {
289 constexpr index_t kBlockSize = Problem::kBlockSize;
290 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
291 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
// Element counts that fit in 16 bytes and in 4 bytes respectively.
292 constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
293 constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
294
// Elements of the tile handled by each thread of the block.
295 constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
296
297 constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
298 ? kMaxVecLoad
299 : (total_pixels / kMinVecLoad);
300
301 return kVecLoad;
302 }
303
304 template <typename Problem>
306 {
308 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
309 using WG = remove_cvref_t<decltype(config.template at<0>())>;
310 using CWarpDstr = typename WG::CWarpDstr;
311 constexpr auto vec =
312 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
313 return vec;
314 }
315
316 template <typename Problem>
318 {
320 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
321 using WG = remove_cvref_t<decltype(config.template at<0>())>;
322 using CWarpDstr = typename WG::CWarpDstr;
323 constexpr auto vec =
324 CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{});
325 return vec;
326 }
327
328 template <typename Problem>
330 {
331 constexpr index_t kBlockSize = Problem::kBlockSize;
332 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
333 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
334
335 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
336
337 return total_pixels / GetAlignmentQ<Problem>();
338 }
339
340 template <typename Problem>
342 {
343 constexpr index_t kBlockSize = Problem::kBlockSize;
344 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
345 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
346 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
347
348 return total_pixels / GetAlignmentK<Problem>();
349 }
350
351 template <typename Problem>
353 {
354 constexpr index_t kBlockSize = Problem::kBlockSize;
355 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
356 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
357
358 constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
359
360 return total_pixels / GetAlignmentOGrad<Problem>();
361 }
362
363 template <typename Problem>
365 {
366 constexpr index_t kBlockSize = Problem::kBlockSize;
367 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
368 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
369
370 constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
371
372 return total_pixels / GetAlignmentBias<Problem>();
373 }
374
375 template <typename Problem>
377 {
379 return 16 / sizeof(AccDataType);
380 }
381
382 template <typename Problem>
384 {
386 }
387
388 template <typename Problem>
390 {
391 constexpr index_t kBlockSize = Problem::kBlockSize;
392
393 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
394 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
395
396 constexpr index_t K1 = GetAlignmentK<Problem>();
397 constexpr index_t K0 = kKPerBlock / K1;
398 constexpr index_t N1 = get_warp_size() / K0;
399 constexpr index_t N0 = kBlockSize / get_warp_size();
400 constexpr index_t N2 = kNPerBlock / (N1 * N0);
401
402 constexpr auto dstr = make_static_tile_distribution(
408 sequence<2, 1>>{});
409
410 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
411 {
412 return dstr;
413 }
414 else
415 {
416 constexpr index_t kKPerIter = 32;
417 static_assert(kKPerBlock % kKPerIter == 0);
418 constexpr index_t K0_m = kKPerBlock / kKPerIter;
419 constexpr index_t K2 = 2;
420 constexpr index_t K1_m = kKPerIter / K2;
421 constexpr index_t N1_m = get_warp_size() / K1_m;
422 constexpr index_t N2_m = kNPerBlock / (N1_m * N0);
423 constexpr auto dstr_m = make_static_tile_distribution(
427 tuple<sequence<1>, sequence<1, 2>>, // N0, N1 K1
429 sequence<2, 1, 2>, // K0 N2 K2
431 static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
432 kNPerBlock * kKPerBlock);
433 return dstr_m;
434 }
435 }
436
437 template <typename Problem>
439 {
440 constexpr index_t kBlockSize = Problem::kBlockSize;
441
442 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
443 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
444
445 constexpr index_t K1 = GetAlignmentV<Problem>();
446 constexpr index_t K0 = kKPerBlock / K1;
447 constexpr index_t N2 = get_warp_size() / K0;
448 constexpr index_t N1 = kBlockSize / get_warp_size();
449 constexpr index_t N0 = kNPerBlock / (N2 * N1);
450
451 constexpr auto dstr = make_static_tile_distribution(
454 tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
456 sequence<1, 2>, // N0 K1
457 sequence<0, 1>>{});
458 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
459 {
460 return dstr;
461 }
462 else
463 {
464 constexpr index_t kKPerIter = 32;
465 static_assert(kKPerBlock % kKPerIter == 0);
466 constexpr index_t K0_m = kKPerBlock / kKPerIter;
467 constexpr index_t K2 = 2;
468 constexpr index_t K1_m = kKPerIter / K2;
469 constexpr index_t N2_m = get_warp_size() / K1_m;
470 constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
471 constexpr auto dstr_m = make_static_tile_distribution(
475 tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
477 sequence<2, 1, 2>, // K0 N0 K2
479 static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
480 kNPerBlock * kKPerBlock);
481 return dstr_m;
482 }
483 }
484
485 template <typename Problem>
487 {
488 constexpr index_t kBlockSize = Problem::kBlockSize;
489
490 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
491 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
492
493 constexpr index_t K1 = GetAlignmentQ<Problem>();
494 constexpr index_t K0 = kKPerBlock / K1;
495 constexpr index_t M1 = get_warp_size() / K0;
496 constexpr index_t M0 = kBlockSize / get_warp_size();
497 constexpr index_t M2 = kMPerBlock / (M1 * M0);
498
499 constexpr auto dstr = make_static_tile_distribution(
505 sequence<2, 1>>{});
506
507 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
508 {
509 return dstr;
510 }
511 else
512 {
513 // something not divisible, try a more flexible distribution
514 constexpr index_t kKPerIter = 32;
515 static_assert(kKPerBlock % kKPerIter == 0);
516 constexpr index_t K0_m = kKPerBlock / kKPerIter;
517 constexpr index_t K2 = 2;
518 constexpr index_t K1_m = kKPerIter / K2;
519 constexpr index_t M1_m = get_warp_size() / K1_m;
520 constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
521 constexpr auto dstr_m = make_static_tile_distribution(
525 tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
527 sequence<2, 1, 2>, // K0 M2 K2
529 static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
530 kMPerBlock * kKPerBlock);
531 return dstr_m;
532 }
533 }
534
535 template <typename Problem>
537 {
538 constexpr index_t kBlockSize = Problem::kBlockSize;
539
540 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
541 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
542
543 constexpr index_t K1 = GetAlignmentOGrad<Problem>();
544 constexpr index_t K0 = kKPerBlock / K1;
545 constexpr index_t M1 = get_warp_size() / K0;
546 constexpr index_t M0 = kBlockSize / get_warp_size();
547 constexpr index_t M2 = kMPerBlock / (M1 * M0);
548
549 constexpr auto dstr = make_static_tile_distribution(
555 sequence<2, 1>>{});
556
557 if constexpr((kKPerBlock & (kKPerBlock - 1)) == 0) // kKPerBlock is power of 2
558 {
559 return dstr;
560 }
561 else
562 {
563 // something not divisible, try a more flexible distribution
564 constexpr index_t kKPerIter = 32;
565 static_assert(kKPerBlock % kKPerIter == 0);
566 constexpr index_t K0_m = kKPerBlock / kKPerIter;
567 constexpr index_t K2 = 2;
568 constexpr index_t K1_m = kKPerIter / K2;
569 constexpr index_t M1_m = get_warp_size() / K1_m;
570 constexpr index_t M2_m = kMPerBlock / (M1_m * M0);
571 constexpr auto dstr_m = make_static_tile_distribution(
575 tuple<sequence<1>, sequence<1, 2>>, // M0, M1 K1
577 sequence<2, 1, 2>, // K0 M2 K2
579 static_assert(container_reduce(dstr_m.get_lengths(), std::multiplies<index_t>{}, 1) ==
580 kMPerBlock * kKPerBlock);
581 return dstr_m;
582 }
583 }
584
585 template <typename Problem, typename BlockGemm>
587 {
588 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
589 constexpr index_t MWarp = config.template at<1>();
590 constexpr index_t NWarp = config.template at<2>();
591
592 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
593
594 // Duplicate dimension
595 constexpr index_t N0 = NWarp;
596 constexpr index_t N1 =
597 (get_warp_size() / kMPerBlock) > 1 ? (get_warp_size() / kMPerBlock) : 1;
598
599 constexpr index_t M0 = MWarp;
600 constexpr index_t M1 = (get_warp_size() / kMPerBlock) > 1 ? kMPerBlock : get_warp_size();
601 constexpr index_t M2 =
602 (get_warp_size() / kMPerBlock) > 1 ? 1 : (kMPerBlock / get_warp_size());
603
610 sequence<2>>{});
611 }
612
613 template <typename Problem>
615 {
616 constexpr index_t kBlockSize = Problem::kBlockSize;
617
618 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
619 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
620
621 constexpr index_t N1 = GetAlignmentBias<Problem>();
622 constexpr index_t N0 = kNPerBlock / N1;
623 constexpr index_t M1 = get_warp_size() / N0;
624 constexpr index_t M0 = kBlockSize / get_warp_size();
625 constexpr index_t M2 = kMPerBlock / (M1 * M0);
626
627 constexpr auto dstr = make_static_tile_distribution(
633 sequence<2, 1>>{});
634 static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
635 kMPerBlock * kNPerBlock);
636 return dstr;
637 }
638
639 template <typename DataType, index_t MPerBlock, index_t KPerBlock>
641 {
642 constexpr index_t K1 = 16 / sizeof(DataType);
643 constexpr index_t K0 = KPerBlock / K1;
644 constexpr index_t M2 = 1;
645 constexpr index_t M1 = get_warp_size();
646 constexpr index_t M0 = MPerBlock / M1;
647
648 constexpr auto dstr = make_static_tile_distribution(
655 static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
656 MPerBlock * KPerBlock);
657 return dstr;
658 }
659
660 template <typename Problem>
662 {
664
665 constexpr index_t kBlockSize = Problem::kBlockSize;
666 constexpr index_t kKPerBlock = Problem::kVHeaddim;
667
669 }
670
671 template <typename Problem>
673 {
675
676 constexpr index_t kBlockSize = Problem::kBlockSize;
677 constexpr index_t kKPerBlock = Problem::kVHeaddim;
678
680 }
681
682 template <typename Problem>
684 {
685 constexpr index_t kBlockSize = Problem::kBlockSize;
686 constexpr index_t kMPerBlock = Problem::kM0;
687 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
688
690 constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size());
691 constexpr index_t K0 = kKPerBlock / (K1 * K2);
692
693 constexpr index_t M2 = get_warp_size() / K1;
694 constexpr index_t M1 = kBlockSize / get_warp_size();
695 constexpr index_t M0 = kMPerBlock / (M1 * M2);
696
697 constexpr auto dstr = make_static_tile_distribution(
705 static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
706 kMPerBlock * kKPerBlock);
707 return dstr;
708 }
709
710 template <typename Problem>
712 {
713 constexpr index_t kBlockSize = Problem::kBlockSize;
714 constexpr index_t kMPerBlock = Problem::kM0;
715 constexpr index_t kKPerBlock = Problem::kQKHeaddim;
716
718 constexpr index_t K1 = min(kKPerBlock / K2, get_warp_size());
719 constexpr index_t K0 = kKPerBlock / (K1 * K2);
720
721 constexpr index_t M2 = get_warp_size() / K1;
722 constexpr index_t M1 = kBlockSize / get_warp_size();
723 constexpr index_t M0 = kMPerBlock / (M1 * M2);
724
725 constexpr auto dstr = make_static_tile_distribution(
732 static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
733 kMPerBlock * kKPerBlock);
734 return dstr;
735 }
736
737 // these are for lds
738 template <typename Problem>
739 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
740 {
741 return GetAlignmentQ<Problem>();
742 }
743
// K-pack query for the QT case (part of the "for lds" group of helpers).
// NOTE(review): listing line 747, the only statement of the body, is missing from this
// view, so the returned value cannot be documented from here.
744 template <typename Problem>
745 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQT()
746 {
748 }
749
750 template <typename Problem>
751 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
752 {
753 return GetAlignmentK<Problem>();
754 }
755
// K-pack query for the KT case (part of the "for lds" group of helpers).
// NOTE(review): listing line 759, the only statement of the body, is missing from this
// view, so the returned value cannot be documented from here.
756 template <typename Problem>
757 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackKT()
758 {
760 }
761
762 template <typename Problem>
763 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
764 {
765 return GetAlignmentV<Problem>();
766 }
767
// K-pack query for the bias tile (part of the "for lds" group of helpers).
// NOTE(review): listing line 771, the only statement of the body, is missing from this
// view, so the returned value cannot be documented from here.
768 template <typename Problem>
769 CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
770 {
772 }
773
774 template <typename Problem>
776 {
778 }
779
780 template <typename Problem>
782 {
784 }
785
786 template <typename Problem>
788 {
790 }
791
792 template <typename Problem>
794 {
795 // TODO: this is for 3d layout
797 return 16 / sizeof(GemmDataType);
798 }
799
800 template <index_t KIter, index_t MNPerBlock, index_t KPerSubBlock, index_t KPack>
802 {
803 constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
804 constexpr auto MNLdsLayer =
805 (32 * 4 / KPerSubBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerSubBlock / DataTypeSize);
806
807 constexpr auto x_lds_block_desc_0 =
809 number<KPerSubBlock / KPack * MNLdsLayer>{},
810 number<MNPerBlock / MNLdsLayer>{},
811 number<KPack>{}),
815 number<1>{}),
817 number<1>{});
818
819 constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
820 x_lds_block_desc_0,
823 number<KPerSubBlock / KPack * MNLdsLayer>{})),
827
828 constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
829 x_lds_block_desc_permuted,
837
838 constexpr auto x_lds_block_desc = transform_tensor_descriptor(
839 x_lds_block_desc_xk0_mnldslayer_mn_xk1,
843 number<KIter>{}, number<KPerSubBlock / KPack>{}, number<KPack>{}))),
846
847 static_assert(container_reduce(x_lds_block_desc.get_lengths(),
848 std::multiplies<index_t>{},
849 1) == KIter * MNPerBlock * KPerSubBlock);
850 return x_lds_block_desc;
851 }
852
853 template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
858 template <typename Problem,
859 index_t MNPerBlock,
860 index_t KPerBlock,
861 index_t KPack,
862 index_t KPackT>
867 template <typename Problem,
868 index_t MNIter,
869 index_t MNPerSubBlock,
870 index_t KPerBlock,
871 index_t KPack,
872 index_t KPackT>
874 {
875 // kfold and mpair dimension is not always required.
876 // more dimension in merge_transform increase the difficulty of generating immarg offset
877 // for compiler.
878 constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
879 constexpr auto kBlockSize = Problem::kBlockSize;
880
881 constexpr auto MN0 = MNPerSubBlock / KPack;
882 constexpr auto MN1 = KPack;
883
884 constexpr auto KThreadWrite = kBlockSize / MN0;
885 constexpr auto K0Number = KPerBlock / KPackT;
886 constexpr auto K0PerThreadWrite = K0Number / KThreadWrite;
887 constexpr auto KThreadRead = get_warp_size() / MNPerXDL; // assume 32x32x8 mfma
888 constexpr auto K0PerThreadRead = K0Number / KThreadRead;
889
890 constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2);
891 constexpr auto KThreadReadPerm =
892 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
893 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
894 : KThreadRead;
895
896 // 1<=mnpair<=n0
897 constexpr auto mnpair =
898 (KPackT * MNPerXDL * 2 > 128)
899 ? 1
900 : ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
901
902 constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
904 number<KThreadWrite / kfold / KThreadReadPerm>{},
907 number<kfold * MN0 / mnpair>{},
909 KPackT),
916 number<1>{}),
918 number<1>{});
919
920 constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
921 xt_lds_block_desc_raw,
927 make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
931 sequence<1>{},
932 sequence<2>{},
934 sequence<5>{},
935 sequence<6>{}),
937 sequence<1>{},
938 sequence<2>{},
940 sequence<5>{},
941 sequence<6>{}));
942
943 constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
944 xt_lds_block_desc_permuted,
954 sequence<1>{},
955 sequence<2>{},
956 sequence<3>{},
957 sequence<4>{},
958 sequence<5>{},
959 sequence<6>{}),
961 sequence<2>{},
962 sequence<3>{},
965 sequence<7>{},
966 sequence<8>{}));
967
968 constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
969 xt_lds_block_desc_unmerged,
973 number<KThreadWrite / kfold / KThreadReadPerm>{},
976 number<KPackT>{})),
978 number<MNIter>{}, number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
981 static_assert(container_reduce(xt_lds_block_desc.get_lengths(),
982 std::multiplies<index_t>{},
983 1) == MNPerSubBlock * MNIter * KPerBlock);
984 return xt_lds_block_desc;
985 }
986
987 template <typename Problem>
989 {
990 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
991 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
992
993 using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
994 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
995 if constexpr(dram_y_ndim == 2)
996 {
997 constexpr index_t kKPack = GetSmemKPackK<Problem>();
999 }
1000 else if constexpr(dram_y_ndim == 3)
1001 {
1002 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1003 constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1004 return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kKPack>();
1005 }
1006 else
1007 {
1008 static_assert(false, "Unexpected dram y dimension");
1009 }
1010 }
1011
1012 template <typename Problem>
1014 {
1016 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1017 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1018
1019 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
1020 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
1021
1022 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1023 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
1024
1025 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1026 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1027
1028 constexpr auto k_block_outer_dstr_encoding =
1034 sequence<0, 0>>{};
1035
1036 constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1037 k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
1038
1039 constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
1040 static_assert(container_reduce(k_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1041 kNPerBlock * kKPerBlock);
1042 return k_block_dstr;
1043 }
1044
1045 template <typename Problem>
1047 {
1048 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1049 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1050
1051 using dram_encoding = typename decltype(MakeVDramTileDistribution<Problem>())::DstrEncode;
1052 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1053 if constexpr(dram_y_ndim == 2)
1054 {
1055 constexpr index_t kVPack = GetSmemKPackV<Problem>();
1057 }
1058 else if constexpr(dram_y_ndim == 3)
1059 {
1060 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1061 constexpr index_t kVPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1062 return MakeXLdsBlockDescriptor<KIter, kNPerBlock, kKPerBlock / KIter, kVPack>();
1063 }
1064 else
1065 {
1066 static_assert(false, "Unexpected dram y dimension");
1067 }
1068 }
1069
1070 template <typename Problem>
1072 {
1074 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1075 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1076
1077 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
1078 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
1079
1080 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1081 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
1082
1083 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1084 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1085
1086 constexpr auto v_block_outer_dstr_encoding =
1092 sequence<0, 0>>{};
1093
1094 constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1095 v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
1096
1097 constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
1098 static_assert(container_reduce(v_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1099 kNPerBlock * kKPerBlock);
1100 return v_block_dstr;
1101 }
1102
1103 template <typename Problem>
1105 {
1106 using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
1107 constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1108 static_assert(y_ndim >= 2);
1109 using shuffled_encoding_t =
1112 return make_static_tile_distribution(shuffled_encoding_t{});
1113 }
1114
// Selects the LDS block descriptor for the shuffled K tile, sized
// kQKHeaddim (N) x kN0 (K). The choice is made at compile time from the number
// of Y dimensions in the K DRAM tile distribution encoding:
//   - 2 dims: pack sizes come from GetSmemKPackK / GetSmemKPackKT
//             (the return statement of this branch is not visible here);
//   - 3 dims: KIter and both pack sizes are read from the encoding's
//             HsLengthss and forwarded to MakeXTLdsBlockDescriptor;
//   - anything else is a compile-time error.
1115 template <typename Problem>
1117 {
1118 // Hold all data
1119 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1120 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1121
1122 using dram_encoding = typename decltype(MakeKDramTileDistribution<Problem>())::DstrEncode;
1123 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1124 if constexpr(dram_y_ndim == 2)
1125 {
1126 constexpr index_t kKPack = GetSmemKPackK<Problem>();
1127 constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
1129 }
1130 else if constexpr(dram_y_ndim == 3)
1131 {
// HsLengthss[1] supplies KIter (index 0) and kKPack (index 2);
// HsLengthss[0] supplies kKPackT (index 2).
1132 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1133 constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1134 constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
1135 return MakeXTLdsBlockDescriptor<Problem,
1136 KIter,
1137 kNPerBlock / KIter,
1138 kKPerBlock,
1139 kKPack,
1140 kKPackT>();
1141 }
1142 else
1143 {
// NOTE(review): an unconditional static_assert(false) is only safe when the
// compiler diagnoses it on instantiation of this branch (CWG2518 behaviour)
// — confirm for all supported toolchains.
1144 static_assert(false, "Unexpected dram y dimension");
1145 }
1146 }
1147
// Builds a descriptor on top of the shuffled K LDS write descriptor
// (MakeShuffledKLdsWriteBlockDescriptor), for a kQKHeaddim x kN0 tile.
// NOTE(review): the call that consumes shuffled_k_lds_block_desc and its
// remaining arguments are not visible in this listing; presumably a
// descriptor transform — confirm against the original header.
1148 template <typename Problem>
1150 {
1151 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1152 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1153
1154 auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor<Problem>();
1155
1157 shuffled_k_lds_block_desc,
1162 }
1163
// Returns the static tile distribution for a kQKHeaddim (N) x kN0 (K) tile
// laid out as the B operand of the warp GEMM: an outer block-level encoding
// (its definition is not visible in this listing) is embedded with
// WarpGemm::BWarpDstrEncoding. Warp counts come from Gemm4BlockWarps.
1164 template <typename Problem>
1166 {
1168 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1169 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1170
1171 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
1172 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
1173
1174 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1175 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1176
// Number of warp-GEMM tiles each warp iterates over along N and K.
1177 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1178 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1179
1180 constexpr auto kt_block_outer_dstr_encoding =
1186 sequence<0, 0>>{};
1187
1188 constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1189 kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
1190
1191 constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the N x K tile.
1192 static_assert(container_reduce(kt_block_dstr.get_lengths(),
1193 std::multiplies<index_t>{},
1194 1) == kNPerBlock * kKPerBlock);
1195 return kt_block_dstr;
1196 }
1197
// Selects the LDS block descriptor for the Q tile, sized kM0 (M) x
// kQKHeaddim (K), based on the number of Y dimensions in the Q DRAM tile
// distribution encoding:
//   - 2 dims: pack size comes from GetSmemKPackQ (the return statement of
//             this branch is not visible in this listing);
//   - 3 dims: KIter and kKPack are read from HsLengthss[1] and forwarded to
//             MakeXLdsBlockDescriptor with the K extent divided by KIter;
//   - anything else is a compile-time error.
1198 template <typename Problem>
1200 {
1201 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1202 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1203
1204 using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
1205 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1206 if constexpr(dram_y_ndim == 2)
1207 {
1208 constexpr index_t kKPack = GetSmemKPackQ<Problem>();
1210 }
1211 else if constexpr(dram_y_ndim == 3)
1212 {
1213 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1214 constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1215 return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
1216 }
1217 else
1218 {
// NOTE(review): an unconditional static_assert(false) is only safe when the
// compiler diagnoses it on instantiation of this branch (CWG2518 behaviour)
// — confirm for all supported toolchains.
1219 static_assert(false, "Unexpected dram y dimension");
1220 }
1221 }
1222
// Returns the static tile distribution for a kM0 (M) x kK0 (K) slice laid
// out as the A operand of the warp GEMM: an outer block-level encoding (its
// definition is not visible in this listing) is embedded with
// WarpGemm::AWarpDstrEncoding. Warp counts come from Gemm0BlockWarps.
1223 template <typename Problem>
1225 {
1227 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1228 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1229
1230 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
1231 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
1232
1233 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1234 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
1235
// Number of warp-GEMM tiles each warp iterates over along M and K.
1236 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1237 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1238
1239 constexpr auto q_block_outer_dstr_encoding =
1245 sequence<0, 0>>{};
1246
1247 constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1248 q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
1249
1250 constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the M x K slice.
1251 static_assert(container_reduce(q_block_dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1252 kMPerBlock * kKPerBlock);
1253 return q_block_dstr;
1254 }
1255
// Returns a static tile distribution built from a shuffled form of the
// encoding of the Q DRAM tile distribution. The source encoding must have at
// least two Y dimensions (enforced by the static_assert).
// NOTE(review): the right-hand side of shuffled_encoding_t is not visible in
// this listing, so the exact permutation applied cannot be confirmed here.
1256 template <typename Problem>
1258 {
1259 using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
1260 constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1261 static_assert(y_ndim >= 2);
1262 using shuffled_encoding_t =
1265 return make_static_tile_distribution(shuffled_encoding_t{});
1266 }
1267
// Selects the LDS block descriptor for the shuffled Q tile, sized
// kQKHeaddim (N) x kM0 (K). The choice is made at compile time from the number
// of Y dimensions in the Q DRAM tile distribution encoding:
//   - 2 dims: pack sizes come from GetSmemKPackQ / GetSmemKPackQT
//             (the return statement of this branch is not visible here);
//   - 3 dims: KIter and both pack sizes are read from the encoding's
//             HsLengthss and forwarded to MakeXTLdsBlockDescriptor;
//   - anything else is a compile-time error.
1268 template <typename Problem>
1270 {
1271 // Hold full block data
1272 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1273 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1274
1275 using dram_encoding = typename decltype(MakeQDramTileDistribution<Problem>())::DstrEncode;
1276 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1277 if constexpr(dram_y_ndim == 2)
1278 {
1279 constexpr index_t kKPack = GetSmemKPackQ<Problem>();
1280 constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
1282 }
1283 else if constexpr(dram_y_ndim == 3)
1284 {
// HsLengthss[1] supplies KIter (index 0) and kKPack (index 2);
// HsLengthss[0] supplies kKPackT (index 2).
1285 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1286 constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1287 constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
1288 return MakeXTLdsBlockDescriptor<Problem,
1289 KIter,
1290 kNPerBlock / KIter,
1291 kKPerBlock,
1292 kKPack,
1293 kKPackT>();
1294 }
1295 else
1296 {
// NOTE(review): an unconditional static_assert(false) is only safe when the
// compiler diagnoses it on instantiation of this branch (CWG2518 behaviour)
// — confirm for all supported toolchains.
1297 static_assert(false, "Unexpected dram y dimension");
1298 }
1299 }
1300
// Builds a descriptor on top of the shuffled Q LDS write descriptor
// (MakeShuffledQLdsWriteBlockDescriptor), for a kQKHeaddim x kM0 tile.
// NOTE(review): the call that consumes shuffled_q_lds_block_desc and its
// remaining arguments are not visible in this listing; presumably a
// descriptor transform — confirm against the original header.
1301 template <typename Problem>
1303 {
1304 // Hold full block data
1305 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1306 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1307
1308 auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor<Problem>();
1309
1311 shuffled_q_lds_block_desc,
1316 }
1317
// Returns the static tile distribution for a kQKHeaddim (N) x kK3 (K) slice
// laid out as the B operand of the warp GEMM: an outer block-level encoding
// (its definition is not visible in this listing) is embedded with
// WarpGemm::BWarpDstrEncoding. Warp counts come from Gemm3BlockWarps.
1318 template <typename Problem>
1320 {
1322 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1323 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1324
1325 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
1326 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
1327
1328 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
1329 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
1330
// Number of warp-GEMM tiles each warp iterates over along N and K.
1331 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1332 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1333
1334 constexpr auto qt_block_outer_dstr_encoding =
1340 sequence<0, 0>>{};
1341
1342 constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1343 qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
1344
1345 constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the N x K slice.
1346 static_assert(container_reduce(qt_block_dstr.get_lengths(),
1347 std::multiplies<index_t>{},
1348 1) == kNPerBlock * kKPerBlock);
1349
1350 return qt_block_dstr;
1351 }
1352
// Returns the static tile distribution for a kN0 (M) x kK3 (K) slice laid
// out as the A operand of the warp GEMM: an outer block-level encoding (its
// definition is not visible in this listing) is embedded with
// WarpGemm::AWarpDstrEncoding. Warp counts come from Gemm3BlockWarps.
1353 template <typename Problem>
1355 {
1357 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1358 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1359
1360 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
1361 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
1362
1363 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1364 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
1365
// Number of warp-GEMM tiles each warp iterates over along M and K.
1366 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1367 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1368
1369 constexpr auto dst_block_outer_dstr_encoding =
1375 sequence<0, 0>>{};
1376
1377 constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1378 dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
1379
1380 constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the M x K slice.
1381 static_assert(container_reduce(dst_block_dstr.get_lengths(),
1382 std::multiplies<index_t>{},
1383 1) == kMPerBlock * kKPerBlock);
1384 return dst_block_dstr;
1385 }
1386
// Returns the LDS block descriptor shared by the LSE and D vectors, sized
// from kM0.
// NOTE(review): the LSEDType alias and most of the descriptor construction
// are not visible in this listing; only its trailing number<1>{} argument is.
1387 template <typename Problem>
1389 {
1390 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
// Number of LSEDType elements that fit in 16 bytes.
1392 constexpr index_t kMPack = 16 / sizeof(LSEDType);
1393
1394 constexpr auto lsed_lds_block_desc =
1398 number<1>{});
1399
1400 return lsed_lds_block_desc;
1401 }
1402
// Returns a static tile distribution whose lengths multiply to kM0. The
// factors M0..M4 and N0/N1 are derived from the warp GEMM and warp counts
// reported by BlockGemm's policy.
// NOTE(review): the distribution encoding passed to
// make_static_tile_distribution is not visible in this listing, so how the
// M*/N* factors are arranged cannot be confirmed here.
1403 template <typename Problem, typename BlockGemm>
1405 {
1406 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1407 using WG = remove_cvref_t<decltype(config.template at<0>())>;
1408 constexpr index_t MWarp = config.template at<1>();
1409 constexpr index_t NWarp = config.template at<2>();
1410
1411 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1412
1413 constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
1414 constexpr index_t N0 = NWarp;
1415
1416 // M4 *2 and M2 /2 when swizzle mode enabled
// The factor is 1 for a 16-row warp GEMM and 2 otherwise.
1417 constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
1418 // constexpr index_t SwizzleConfig = 1;
1419 constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
1420 constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
1421 constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
1422 constexpr index_t M1 = MWarp;
// Remaining repeats along M after the per-warp-GEMM rows and MWarp warps.
1423 constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
1424
1425 constexpr auto dstr = make_static_tile_distribution(
// Sanity check: the distribution covers exactly kMPerBlock elements.
1432 static_assert(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
1433 kMPerBlock);
1434 return dstr;
1435 }
1436
// Selects the LDS block descriptor for the OGrad tile, sized kM0 (M) x
// kVHeaddim (K), based on the number of Y dimensions in the OGrad DRAM tile
// distribution encoding:
//   - 2 dims: pack size comes from GetSmemKPackOGrad (the return statement
//             of this branch is not visible in this listing);
//   - 3 dims: KIter and kKPack are read from HsLengthss[1] and forwarded to
//             MakeXLdsBlockDescriptor with the K extent divided by KIter;
//   - anything else is a compile-time error.
1437 template <typename Problem>
1439 {
1440 // Hold full block data
1441 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1442 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1443
1444 using dram_encoding =
1445 typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
1446 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1447 if constexpr(dram_y_ndim == 2)
1448 {
1449 constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
1451 }
1452 else if constexpr(dram_y_ndim == 3)
1453 {
1454 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1455 constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1456 return MakeXLdsBlockDescriptor<KIter, kMPerBlock, kKPerBlock / KIter, kKPack>();
1457 }
1458 else
1459 {
// NOTE(review): an unconditional static_assert(false) is only safe when the
// compiler diagnoses it on instantiation of this branch (CWG2518 behaviour)
// — confirm for all supported toolchains.
1460 static_assert(false, "Unexpected dram y dimension");
1461 }
1462 }
1463
// Returns the static tile distribution for a kM0 (M) x kK2 (K) slice laid
// out as the A operand of the warp GEMM: an outer block-level encoding (its
// definition is not visible in this listing) is embedded with
// WarpGemm::AWarpDstrEncoding. Warp counts come from Gemm2BlockWarps.
1464 template <typename Problem>
1466 {
1468 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1469 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1470
1471 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
1472 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
1473
1474 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1475 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
1476
// Number of warp-GEMM tiles each warp iterates over along M and K.
1477 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1478 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1479
1480 constexpr auto do_block_outer_dstr_encoding =
1486 sequence<0, 0>>{};
1487
1488 constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1489 do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
1490
1491 constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the M x K slice.
1492 static_assert(container_reduce(do_block_dstr.get_lengths(),
1493 std::multiplies<index_t>{},
1494 1) == kMPerBlock * kKPerBlock);
1495 return do_block_dstr;
1496 }
1497
// Returns a static tile distribution built from a shuffled form of the
// encoding of the OGrad DRAM tile distribution. The source encoding must have
// at least two Y dimensions (enforced by the static_assert).
// NOTE(review): the right-hand side of shuffled_encoding_t is not visible in
// this listing, so the exact permutation applied cannot be confirmed here.
1498 template <typename Problem>
1500 {
1501
1502 using dram_encoding =
1503 typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
1504 constexpr index_t y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1505 static_assert(y_ndim >= 2);
1506 using shuffled_encoding_t =
1509 return make_static_tile_distribution(shuffled_encoding_t{});
1510 }
1511
// Selects the LDS block descriptor for the shuffled OGrad tile, sized
// kVHeaddim (N) x kM0 (K). The choice is made at compile time from the number
// of Y dimensions in the OGrad DRAM tile distribution encoding:
//   - 2 dims: pack sizes come from GetSmemKPackOGrad / GetSmemKPackOGradT
//             (the return statement of this branch is not visible here);
//   - 3 dims: KIter and both pack sizes are read from the encoding's
//             HsLengthss and forwarded to MakeXTLdsBlockDescriptor;
//   - anything else is a compile-time error.
1512 template <typename Problem>
1514 {
1515 // Hold all data
1516 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1517 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1518
1519 using dram_encoding =
1520 typename decltype(MakeOGradDramTileDistribution<Problem>())::DstrEncode;
1521 constexpr index_t dram_y_ndim = typename dram_encoding::Ys2RHsMajor{}.size();
1522 if constexpr(dram_y_ndim == 2)
1523 {
1524 constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
1525 constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
1527 }
1528 else if constexpr(dram_y_ndim == 3)
1529 {
// HsLengthss[1] supplies KIter (index 0) and kKPack (index 2);
// HsLengthss[0] supplies kKPackT (index 2).
1530 constexpr index_t KIter = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(0);
1531 constexpr index_t kKPack = typename dram_encoding::HsLengthss{}.at(number<1>{}).at(2);
1532 constexpr index_t kKPackT = typename dram_encoding::HsLengthss{}.at(number<0>{}).at(2);
1533 return MakeXTLdsBlockDescriptor<Problem,
1534 KIter,
1535 kNPerBlock / KIter,
1536 kKPerBlock,
1537 kKPack,
1538 kKPackT>();
1539 }
1540 else
1541 {
// NOTE(review): an unconditional static_assert(false) is only safe when the
// compiler diagnoses it on instantiation of this branch (CWG2518 behaviour)
// — confirm for all supported toolchains.
1542 static_assert(false, "Unexpected dram y dimension");
1543 }
1544 }
1545
// Builds a descriptor on top of the shuffled OGrad LDS write descriptor
// (MakeShuffledOGradLdsWriteBlockDescriptor), for a kVHeaddim x kM0 tile.
// NOTE(review): the call that consumes shuffled_do_lds_block_desc and its
// remaining arguments are not visible in this listing; presumably a
// descriptor transform — confirm against the original header.
1546 template <typename Problem>
1548 {
1549 // Hold all data
1550 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1551 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
1552 auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor<Problem>();
1553
1555 shuffled_do_lds_block_desc,
1560 }
1561
// Returns the static tile distribution for a kVHeaddim (N) x kK1 (K) slice
// laid out as the B operand of the warp GEMM: an outer block-level encoding
// (its definition is not visible in this listing) is embedded with
// WarpGemm::BWarpDstrEncoding. Warp counts come from Gemm1BlockWarps.
1562 template <typename Problem>
1564 {
1566 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1567 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1568
1569 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
1570 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
1571
1572 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
1573 // constexpr index_t kNPerBlock = 32;
1574 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
1575
// Number of warp-GEMM tiles each warp iterates over along N and K.
1576 constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
1577 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1578
1579 constexpr auto dot_block_outer_dstr_encoding =
1585 sequence<0, 0>>{};
1586
1587 constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1588 dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
1589
1590 constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the N x K slice.
1591 static_assert(container_reduce(dot_block_dstr.get_lengths(),
1592 std::multiplies<index_t>{},
1593 1) == kNPerBlock * kKPerBlock);
1594 return dot_block_dstr;
1595 }
1596
// Returns the static tile distribution for a kN0 (M) x kK1 (K) slice laid
// out as the A operand of the warp GEMM: an outer block-level encoding (its
// definition is not visible in this listing) is embedded with
// WarpGemm::AWarpDstrEncoding. Warp counts come from Gemm1BlockWarps.
1597 template <typename Problem>
1599 {
1601 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1602 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1603
1604 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
1605 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
1606
1607 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1608 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
1609
// Number of warp-GEMM tiles each warp iterates over along M and K.
1610 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1611 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1612
1613 constexpr auto pt_block_outer_dstr_encoding =
1619 sequence<0, 0>>{};
1620
1621 constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1622 pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
1623
1624 constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the M x K slice.
1625 static_assert(container_reduce(pt_block_dstr.get_lengths(),
1626 std::multiplies<index_t>{},
1627 1) == kMPerBlock * kKPerBlock);
1628 return pt_block_dstr;
1629 }
1630
// Gathers the parameters of the SGrad LDS block descriptor: a kM0 (M) x kN0
// (K) tile with the pack size reported by GetSmemKPackSGrad.
// NOTE(review): the return statement is not visible in this listing, so the
// helper that consumes these values cannot be confirmed here.
1631 template <typename Problem>
1633 {
1634 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1635 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
1636 constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
1637
1639 }
1640
// Returns the static tile distribution for a kM0 (M) x kK4 (K) slice laid
// out as the A operand of the warp GEMM: an outer block-level encoding (its
// definition is not visible in this listing) is embedded with
// WarpGemm::AWarpDstrEncoding. Warp counts come from Gemm4BlockWarps.
1641 template <typename Problem>
1643 {
1645 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1646 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1647
1648 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
1649 constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
1650
1651 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1652 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
1653
// Number of warp-GEMM tiles each warp iterates over along M and K.
1654 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1655 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1656
1657 constexpr auto ds_block_outer_dstr_encoding =
1663 sequence<0, 0>>{};
1664
1665 constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
1666 ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
1667
1668 constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
// Sanity check: the product of the distribution's lengths must equal the
// element count of the M x K slice.
1669 static_assert(container_reduce(ds_block_dstr.get_lengths(),
1670 std::multiplies<index_t>{},
1671 1) == kMPerBlock * kKPerBlock);
1672 return ds_block_dstr;
1673 }
1674
// Copies the per-thread data of p_in into pt_out.
//
// When Gemm1's warp tile M dimension is 16, the copy is done one warp tile at
// a time: the slice read from p_in at outer Y index (kIter, mIter) is written
// to pt_out at outer Y index (mIter, kIter), i.e. the two outer iteration
// indices are swapped between source and destination. Otherwise the whole
// thread buffer is assigned in one step.
//
// pt_out: destination tensor, written in place.
// p_in:   source tensor, read only.
// NOTE(review): the BlockGemm alias and the initializer of pt_warp_tensor are
// not visible in this listing.
1675 template <typename Problem, typename PTOutTensor, typename PInTensor>
1676 CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
1677 const PInTensor& p_in)
1678 {
1679 if constexpr(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}) == 16)
1680 {
1682 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1683 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1684
1685 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
1686
1687 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1688 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
1689
1690 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1691 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1692
1693 using AWarpDstr = typename WarpGemm::AWarpDstr;
1694 using CWarpDstr = typename WarpGemm::CWarpDstr;
// Scratch tensor holding one warp tile while it moves from source to destination.
1695 auto pt_warp_tensor =
1697
// Per-warp Y lengths of the A (destination) and C (source) layouts.
1698 constexpr auto a_warp_y_lengths =
1699 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1700 constexpr auto c_warp_y_lengths =
1701 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1702
// All-zero start offsets for the per-warp Y dimensions.
1703 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
1704 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
1705
1706 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
1707 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// Read the warp tile at (kIter, mIter) from the source ...
1708 pt_warp_tensor.get_thread_buffer() = p_in.get_y_sliced_thread_data(
1709 merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
1710 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1711
// ... and store it at (mIter, kIter) in the destination.
1712 pt_out.set_y_sliced_thread_data(
1713 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
1714 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
1715 pt_warp_tensor.get_thread_buffer());
1716 });
1717 });
1718 }
1719 else
1720 {
// Layouts are taken to match here: copy the thread buffer wholesale.
1721 pt_out.get_thread_buffer() = p_in.get_thread_buffer();
1722 }
1723 }
1724
// Copies the per-thread data of ds_in into dst_out.
//
// When Gemm3's warp tile M dimension is 16, the copy is done one warp tile at
// a time: the slice read from ds_in at outer Y index (kIter, mIter) is written
// to dst_out at outer Y index (mIter, kIter), i.e. the two outer iteration
// indices are swapped between source and destination. Otherwise the whole
// thread buffer is assigned in one step.
//
// dst_out: destination tensor, written in place.
// ds_in:   source tensor, read only.
// NOTE(review): the BlockGemm alias and the initializer of dst_warp_tensor
// are not visible in this listing.
1725 template <typename Problem, typename SGradTOutTensor, typename SGradInTensor>
1726 CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
1727 const SGradInTensor& ds_in)
1728 {
1729 if constexpr(Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}) == 16)
1730 {
1732 constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
1733 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
1734
1735 constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
1736
1737 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
1738 constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
1739
1740 constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
1741 constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
1742
1743 using AWarpDstr = typename WarpGemm::AWarpDstr;
1744 using CWarpDstr = typename WarpGemm::CWarpDstr;
// Scratch tensor holding one warp tile while it moves from source to destination.
1745 auto dst_warp_tensor =
1747
// Per-warp Y lengths of the A (destination) and C (source) layouts.
1748 constexpr auto a_warp_y_lengths =
1749 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1750 constexpr auto c_warp_y_lengths =
1751 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
1752
// All-zero start offsets for the per-warp Y dimensions.
1753 constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
1754 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
1755
1756 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
1757 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// Read the warp tile at (kIter, mIter) from the source ...
1758 dst_warp_tensor.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
1759 merge_sequences(sequence<kIter, mIter>{}, c_warp_y_index_zeros),
1760 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
1761
// ... and store it at (mIter, kIter) in the destination.
1762 dst_out.set_y_sliced_thread_data(
1763 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
1764 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths),
1765 dst_warp_tensor.get_thread_buffer());
1766 });
1767 });
1768 }
1769 else
1770 {
// Layouts are taken to match here: copy the thread buffer wholesale.
1771 dst_out.get_thread_buffer() = ds_in.get_thread_buffer();
1772 }
1773 }
1774
// Computes the factors of the bias tile distribution from the block size,
// kN0, the warp size and GetAlignmentBias.
// NOTE(review): one factor definition (between N0 and M1) and almost all of
// the returned distribution encoding are not visible in this listing; only
// its trailing sequence<1, 2> is.
1775 template <typename Problem>
1777 {
1778 constexpr index_t kBlockSize = Problem::kBlockSize;
1779
1780 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1781
// N1: value reported by GetAlignmentBias; N0: number of N1-sized chunks per row.
1782 constexpr index_t N1 = GetAlignmentBias<Problem>();
1783 constexpr index_t N0 = kNPerBlock / N1;
// M1: warp size divided by N0; M0: number of warps in the block.
1785 constexpr index_t M1 = get_warp_size() / N0;
1786 constexpr index_t M0 = kBlockSize / get_warp_size();
1787
1794 sequence<1, 2>>{});
1795 }
1796
// Gathers the parameters of the bias LDS block descriptor: a kN0 x kM0 tile
// with pack sizes reported by GetSmemKPackBias and GetSmemKPackBiasT.
// NOTE(review): the return statement is not visible in this listing, so the
// helper that consumes these values cannot be confirmed here.
1797 template <typename Problem>
1799 {
1800 // Hold full block data
1801 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
1802 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
1803
1804 constexpr index_t kKPack = GetSmemKPackBias<Problem>();
1805 constexpr index_t kKPackT = GetSmemKPackBiasT<Problem>();
1806
1808 }
1809
// Returns the tile distribution of the C block tile produced by BlockGemm,
// obtained from the type of BlockGemm{}.MakeCBlockTile().
1810 template <typename BlockGemm>
1812 {
1813 using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
1814 return c_block_tensor_type::get_tile_distribution();
1815 }
1816
// LDS size for Q: sizeof(QDataType) times the element space size of
// MakeQLdsBlockDescriptor.
1817 template <typename Problem>
1819 {
1820 constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) *
1821 MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
1822 return smem_size_q;
1823 }
1824
// LDS size for the shuffled Q tile: sizeof(QDataType) times the element
// space size of MakeShuffledQLdsWriteBlockDescriptor.
1825 template <typename Problem>
1827 {
1828 constexpr index_t smem_size_qt =
1829 sizeof(typename Problem::QDataType) *
1830 MakeShuffledQLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1831
1832 return smem_size_qt;
1833 }
1834
// LDS size for K: sizeof(KDataType) times the element space size of
// MakeKLdsWriteBlockDescriptor.
1835 template <typename Problem>
1837 {
1838 constexpr index_t smem_size_k =
1839 sizeof(typename Problem::KDataType) *
1840 MakeKLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1841 return smem_size_k;
1842 }
1843
// LDS size for K^T: sizeof(KDataType) times the element space size of
// MakeKTLdsReadBlockDescriptor.
1844 template <typename Problem>
1846 {
1847 constexpr index_t smem_size_kt =
1848 sizeof(typename Problem::KDataType) *
1849 MakeKTLdsReadBlockDescriptor<Problem>().get_element_space_size();
1850 return smem_size_kt;
1851 }
1852
// LDS size for LSE: sizeof(LSEDataType) times the element space size of
// MakeLSEDLdsWriteBlockDescriptor.
1853 template <typename Problem>
1855 {
1856 constexpr index_t smem_size_lse =
1857 sizeof(typename Problem::LSEDataType) *
1858 MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1859 return smem_size_lse;
1860 }
1861
// LDS size for D: sizeof(DDataType) times the element space size of
// MakeLSEDLdsWriteBlockDescriptor, the same descriptor used for the LSE size.
1862 template <typename Problem>
1864 {
1865 constexpr index_t smem_size_d =
1866 sizeof(typename Problem::DDataType) *
1867 MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1868 return smem_size_d;
1869 }
1870
// LDS size for V: sizeof(VDataType) times the element space size of
// MakeVLdsWriteBlockDescriptor.
1871 template <typename Problem>
1873 {
1874 constexpr index_t smem_size_v =
1875 sizeof(typename Problem::VDataType) *
1876 MakeVLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1877 return smem_size_v;
1878 }
1879
// LDS size for OGrad: sizeof(OGradDataType) times the element space size of
// MakeOGradLdsBlockDescriptor.
1880 template <typename Problem>
1882 {
1883 constexpr index_t smem_size_do =
1884 sizeof(typename Problem::OGradDataType) *
1885 MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
1886 return smem_size_do;
1887 }
1888
// LDS size for the shuffled OGrad tile: sizeof(OGradDataType) times the
// element space size of MakeShuffledOGradLdsWriteBlockDescriptor.
1889 template <typename Problem>
1891 {
1892 constexpr index_t smem_size_dot =
1893 sizeof(typename Problem::OGradDataType) *
1894 MakeShuffledOGradLdsWriteBlockDescriptor<Problem>().get_element_space_size();
1895 return smem_size_dot;
1896 }
1897
// LDS size for SGrad: sizeof(GemmDataType) times the element space size of
// MakeSGradLdsBlockDescriptor.
1898 template <typename Problem>
1900 {
1901 constexpr index_t smem_size_ds =
1902 sizeof(typename Problem::GemmDataType) *
1903 MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
1904 return smem_size_ds;
1905 }
1906
// LDS size for the bias tile. Non-zero only when Problem::BiasEnum is
// ELEMENTWISE_BIAS, in which case it is sizeof(BiasDataType) times the element
// space size of MakeBiasLdsBlockDescriptor; otherwise 0.
1907 template <typename Problem>
1909 {
// Immediately-invoked lambda so the if-constexpr selection can initialize a
// constexpr variable.
1910 constexpr index_t smem_size_bias = [&]() {
1911 if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1912 return sizeof(typename Problem::BiasDataType) *
1913 MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
1914 else
1915 return 0;
1916 }();
1917 return smem_size_bias;
1918 }
1919
// Overall LDS size: the largest of three per-stage sums built from the
// individual GetSmemSize* values.
//   stage0_0 = K + K^T
//   stage0_1 = V
//   stage1   = Q^T + Q + OGrad^T + OGrad + LSE + D + max(bias, SGrad)
1920 template <typename Problem>
1922 {
1923 constexpr index_t smem_size_q = GetSmemSizeQ<Problem>();
1924 constexpr index_t smem_size_qt = GetSmemSizeQT<Problem>();
1925 constexpr index_t smem_size_lse = GetSmemSizeLSE<Problem>();
1926 constexpr index_t smem_size_k = GetSmemSizeK<Problem>();
1927 constexpr index_t smem_size_kt = GetSmemSizeKT<Problem>();
1928 constexpr index_t smem_size_v = GetSmemSizeV<Problem>();
1929 constexpr index_t smem_size_do = GetSmemSizeOGrad<Problem>();
1930 constexpr index_t smem_size_dot = GetSmemSizeOGradT<Problem>();
1931 constexpr index_t smem_size_d = GetSmemSizeD<Problem>();
1932 constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
1933 constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
1934
1935 constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
1936 constexpr index_t smem_size_stage0_1 = smem_size_v;
// Bias and SGrad contribute only the larger of the two to this stage.
1937 constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
1938 smem_size_do + smem_size_lse + smem_size_d +
1939 max(smem_size_bias, smem_size_ds);
1940
1941 return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
1942 }
1943
// Scheduling helper for the hot loop. For each of five GEMM stages it emits
// __builtin_amdgcn_sched_group_barrier hints that interleave MFMA groups with
// VMEM-read, DS-read or DS-write groups. The instruction counts that drive the
// interleaving are compile-time estimates defined in the private section from
// the block shape, the block size, the warp size and the alignment helpers.
// NOTE(review): the struct declaration line and the signature lines of the
// five specializations are not visible in this listing; the stage numbering
// below is inferred from the GemmNMFMA constant each body uses.
1944 template <typename Problem_>
1946 {
1947 using Problem = Problem_;
1948
// Primary template: emits no scheduling hints.
1949 template <index_t GemmStage>
1950 CK_TILE_DEVICE static constexpr void GemmStagedScheduler()
1951 {
1952 }
1953
// Specialization that uses Gemm0MFMA: spreads the VMEM reads evenly across the
// MFMA instructions and pairs each MFMA with a share of the DS reads.
1954 template <>
1956 {
1957 // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
1958 // Comp: Q x K
1959 constexpr index_t VMEM_READ_INST =
1960 Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
1961 constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
1962 constexpr index_t MFMA_INST = Gemm0MFMA;
1963
1964 // Evenly distributed to relieve SQ->TA FIFO pressure
1965 constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
1966 constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
1967 // To hide instruction issue latency
// NOTE(review): unlike the Gemm3/Gemm4 specializations this ratio is not
// clamped to a minimum of 1, so it is 0 whenever LDS_READ_INST < MFMA_INST —
// confirm that is intended.
1968 constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
1969
// One VMEM-read group per iteration, followed by MFMA + DS-read groups
// (the inner loop header is not visible in this listing).
1970 static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
1971 ignore = i;
1972 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
1974 ignore = j;
1975 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1976 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
1977 });
1978 });
// MFMA instructions left over after the even split above.
1979 static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
1980 ignore = i;
1981 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
1982 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
1983 });
1984 }
1985
// Specialization that uses Gemm1MFMA: pairs each MFMA with a share of the
// Q^T DS reads.
1986 template <>
1988 {
1989 // Mem: Q^T LDS load
1990 // Comp: OGrad x V
1991 constexpr index_t LDS_READ_INST = QT_LDS_READ;
1992 constexpr index_t MFMA_INST = Gemm1MFMA;
1993
1994 // To hide instruction issue latency
// NOTE(review): not clamped to a minimum of 1 (see Gemm3/Gemm4 variants).
1995 constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
1996
1997 static_for<0, MFMA_INST, 1>{}([&](auto i) {
1998 ignore = i;
1999 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
2000 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
2001 });
2002 }
2003
// Specialization that uses Gemm2MFMA: pairs each MFMA with a share of the
// DS writes.
2004 template <>
2006 {
2007 // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
2008 // Comp: PT x OGrad
2009 constexpr index_t LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE +
2010 OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE;
2011 constexpr index_t MFMA_INST = Gemm2MFMA;
2012
2013 // To hide instruction issue latency
// NOTE(review): not clamped to a minimum of 1 (see Gemm3/Gemm4 variants).
2014 constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;
2015
2016 static_for<0, MFMA_INST, 1>{}([&](auto i) {
2017 ignore = i;
2018 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
2019 __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
2020 });
2021 }
2022
// Specialization that uses Gemm3MFMA: a first run of MFMAs is paired with DS
// writes, and the remaining MFMAs are paired with DS reads.
2023 template <>
2025 {
2026 // Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
2027 // Comp: SGradT x QT
2028 constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
2029 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
2030 constexpr index_t MFMA_INST = Gemm3MFMA;
2031
2032 // To hide instruction issue latency
// At least one DS write per MFMA; MFMA_INST_LDS_WRITE is the number of MFMAs
// needed to cover all writes at that rate.
2033 constexpr index_t LDS_WRITE_PER_MFMA =
2034 LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
2035 constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
2036
// DS reads per remaining MFMA: 0 when no MFMAs remain, otherwise at least 1.
2037 constexpr index_t LDS_READ_PER_MFMA =
2038 (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
2039 ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
2040 ? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
2041 : 1
2042 : 0;
2043
// DS-write phase (the loop header is not visible in this listing).
2045 ignore = i;
2046 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
2047 __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
2048 });
2049
// DS-read phase over the MFMAs not consumed by the write phase.
2050 static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
2051 ignore = i;
2052 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
2053 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
2054 });
2055 }
2056
// Specialization that uses Gemm4MFMA: pairs each MFMA with at least one DS read.
2057 template <>
2059 {
2060 // Mem: SGrad, OGrad, D LDS load.
2061 // Comp: SGrad x KT
2062 constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
2063 constexpr index_t MFMA_INST = Gemm4MFMA;
2064
2065 // To hide instruction issue latency
2066 constexpr index_t LDS_READ_PER_MFMA =
2067 LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
2068
2069 static_for<0, MFMA_INST, 1>{}([&](auto i) {
2070 ignore = i;
2071 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
2072 __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
2073 });
2074 }
2075
2076 private:
// Shorthand for the block shape values used in the estimates below.
2077 static constexpr index_t kBlockSize = Problem::kBlockSize;
2078 static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
2079 static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
2080 static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
2081 static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
2082 static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
2083 static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
2084 static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
2085
// Warp tile dimensions. M and N are read from Gemm0WarpTile; K is derived
// from M (16 when M == 16, otherwise 8) rather than read from the tile.
2086 static constexpr index_t WarpGemmM =
2087 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
2088 static constexpr index_t WarpGemmN =
2089 Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
2090 static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
2091 static constexpr index_t Gemm4MWarp =
2092 Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
2093 static constexpr index_t Gemm4NWarp =
2094 Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
2095
2096 // Compute
// MFMA count per GEMM: tile volume divided by (warps per block * warp tile
// volume).
// NOTE(review): all five counts reuse the Gemm0-derived warp tile dimensions —
// confirm every GEMM shares that warp tile.
2097 static constexpr index_t Gemm0MFMA =
2098 kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2099 static constexpr index_t Gemm1MFMA =
2100 kN0 * kVHeaddim * kM0 /
2101 (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2102 static constexpr index_t Gemm2MFMA =
2103 kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2104 static constexpr index_t Gemm3MFMA =
2105 kN0 * kQKHeaddim * kM0 /
2106 (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2107 static constexpr index_t Gemm4MFMA =
2108 kM0 * kQKHeaddim * kN0 /
2109 (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
2110
2111 // VMEM
// Tile element count divided by the block size and by the alignment helper's
// value; LSE and D are counted as one read each.
2112 static constexpr index_t Q_VMEM_READ =
2113 kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
2114 static constexpr index_t OGrad_VMEM_READ =
2115 kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
2116 static constexpr index_t LSE_VMEM_READ = 1;
2117 static constexpr index_t D_VMEM_READ = 1;
2118
2119 // LDS Read
// NOTE(review): the initializer of OGradT_LDS_READ is not visible in this listing.
2120 static constexpr index_t OGradT_LDS_READ =
2122 static constexpr index_t QT_LDS_READ =
2123 kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
// SGradT reads are split in two parts: P1 covers the first kK4 columns, P2
// the remaining kN0 - kK4.
2124 static constexpr index_t SGradT_LDS_READ_P1 =
2125 kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
2126 static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
2127 static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
2128 static constexpr index_t SGradT_LDS_READ_P2 =
2129 kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
2130 static constexpr index_t OGrad_LDS_READ =
2131 kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
2132 static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
2133
2134 // LDS Write
// Tile element count divided by the block size and by the (transposed)
// alignment helper's value; LSE and D are counted as one write each.
2135 static constexpr index_t Q_LDS_WRITE =
2136 kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
2137 static constexpr index_t QT_LDS_WRITE =
2138 kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
2139 static constexpr index_t OGrad_LDS_WRITE =
2140 kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
2141 static constexpr index_t OGradT_LDS_WRITE =
2142 kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
2143 static constexpr index_t LSE_LDS_WRITE = 1;
2144 static constexpr index_t D_LDS_WRITE = 1;
2145 static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
2146 };
2147};
2148
2149} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition tile/core/container/container_helper.hpp:198
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
CK_TILE_HOST_DEVICE constexpr auto generate_sequence_v2(F &&f, number< N >)
Definition tile/core/container/sequence.hpp:1045
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
typename tile_distribution_encoding_shuffle< encoding, shuffle >::type tile_distribution_encoding_shuffle_t
Definition tile_distribution_encoding.hpp:451
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_HOST_DEVICE constexpr auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition coordinate_transform.hpp:1622
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1662
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_HOST_DEVICE constexpr T min(T x)
Definition tile/core/numeric/math.hpp:210
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_bwd_pipeline_default_policy.hpp:1946
Problem_ Problem
Definition block_fmha_bwd_pipeline_default_policy.hpp:1947
static CK_TILE_DEVICE constexpr void GemmStagedScheduler()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1950
Definition block_fmha_bwd_pipeline_default_policy.hpp:23
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1881
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1513
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1499
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1438
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGradAcc()
Definition block_fmha_bwd_pipeline_default_policy.hpp:376
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1116
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackKT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:757
static CK_TILE_HOST_DEVICE constexpr auto MakeQRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1224
static CK_TILE_HOST_DEVICE constexpr auto MakeKDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:389
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1354
static CK_TILE_HOST_DEVICE constexpr auto GetOGradVBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:104
static CK_TILE_HOST_DEVICE constexpr auto MakeKTLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1149
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1563
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1404
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentO()
Definition block_fmha_bwd_pipeline_default_policy.hpp:260
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackBiasT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:775
static CK_TILE_HOST_DEVICE constexpr auto GetQKBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:32
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1908
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:536
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1642
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:341
static CK_TILE_HOST_DEVICE constexpr auto MakePreOGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:672
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledKRegWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1104
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeLSE()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1854
static CK_TILE_HOST_DEVICE constexpr auto MakeVRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1071
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:228
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:267
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1818
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentVGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:317
static CK_TILE_DEVICE constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor &dst_out, const SGradInTensor &ds_in)
Definition block_fmha_bwd_pipeline_default_policy.hpp:1726
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1836
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1465
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:854
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1798
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:769
static CK_TILE_HOST_DEVICE constexpr auto MakeOGradTLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1547
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeD()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1863
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackSGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:793
static CK_TILE_HOST_DEVICE constexpr auto MakeVLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1046
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentPostQGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:383
static constexpr auto swap_last2
Definition block_fmha_bwd_pipeline_default_policy.hpp:25
static CK_TILE_HOST_DEVICE constexpr auto MakeXLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:801
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackK()
Definition block_fmha_bwd_pipeline_default_policy.hpp:751
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:781
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:209
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:364
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1921
static CK_TILE_HOST_DEVICE constexpr auto MakePTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1598
static CK_TILE_HOST_DEVICE constexpr auto MakeSGradLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1632
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeQT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1826
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:711
static CK_TILE_HOST_DEVICE constexpr auto MakeVDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:438
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeV()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1872
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentOGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:352
static CK_TILE_HOST_DEVICE constexpr auto MakeXTLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:873
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeOGradT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1890
static CK_TILE_HOST_DEVICE constexpr auto MakeQDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:486
static CK_TILE_HOST_DEVICE constexpr auto MakePreODramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:661
static CK_TILE_HOST_DEVICE constexpr auto MakeXTLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:863
static CK_TILE_HOST_DEVICE constexpr auto GetSGradTQTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:138
static CK_TILE_HOST_DEVICE constexpr auto MakeKLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:988
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:586
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeSGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1899
static CK_TILE_DEVICE constexpr auto GetPTOGradTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:66
static CK_TILE_HOST_DEVICE constexpr auto GetSGradKTBlockGemm()
Definition block_fmha_bwd_pipeline_default_policy.hpp:176
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:614
static CK_TILE_HOST_DEVICE constexpr auto GetTransposedAlignmentQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:329
static CK_TILE_HOST_DEVICE constexpr auto MakeQTRegSliceBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1319
static CK_TILE_HOST_DEVICE constexpr auto MakeBiasSTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1811
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackOGradT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:787
static CK_TILE_HOST_DEVICE constexpr auto MakeQTLdsReadBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1302
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1269
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackV()
Definition block_fmha_bwd_pipeline_default_policy.hpp:763
static CK_TILE_DEVICE constexpr void PTFromGemm0CToGemm1A(PTOutTensor &pt_out, const PInTensor &p_in)
Definition block_fmha_bwd_pipeline_default_policy.hpp:1676
static CK_TILE_HOST_DEVICE constexpr auto MakePreXDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:640
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:745
static CK_TILE_HOST_DEVICE constexpr auto MakeQLdsBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1199
static CK_TILE_HOST_DEVICE constexpr auto MakeKTRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1165
static CK_TILE_HOST_DEVICE constexpr auto MakeKRegBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1013
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledQRegWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1257
static CK_TILE_HOST_DEVICE constexpr auto MakeShuffledBiasTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1776
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentKGrad()
Definition block_fmha_bwd_pipeline_default_policy.hpp:305
static CK_TILE_HOST_DEVICE constexpr auto MakeLSEDLdsWriteBlockDescriptor()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1388
static CK_TILE_HOST_DEVICE constexpr auto GetSmemKPackQ()
Definition block_fmha_bwd_pipeline_default_policy.hpp:739
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentBias()
Definition block_fmha_bwd_pipeline_default_policy.hpp:286
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSizeKT()
Definition block_fmha_bwd_pipeline_default_policy.hpp:1845
static CK_TILE_HOST_DEVICE constexpr auto MakePostQGradAccDramTileDistribution()
Definition block_fmha_bwd_pipeline_default_policy.hpp:683
static CK_TILE_HOST_DEVICE constexpr auto GetAlignmentV()
Definition block_fmha_bwd_pipeline_default_policy.hpp:247
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
Definition block_gemm_areg_breg_creg_v1.hpp:18
Definition block_gemm_problem.hpp:18
Definition tile_gemm_shape.hpp:17
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192