wp_pipeline_agmem_bgmem_creg_v2.hpp Source File

wp_pipeline_agmem_bgmem_creg_v2.hpp Source File

Composable Kernel: wp_pipeline_agmem_bgmem_creg_v2.hpp Source File
wp_pipeline_agmem_bgmem_creg_v2.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12
13template <typename Problem>
15{
16 static constexpr index_t PrefetchStages = 2;
17 static constexpr index_t PrefillStages = 1;
18 static constexpr index_t GlobalBufferNum = 1;
19 static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
20
21 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
22
23 CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
24 {
25 return num_loop > PrefetchStages;
26 }
27
29 {
30 return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
31 }
32
33 template <typename RunFunction>
34 CK_TILE_HOST_DEVICE static auto
35 TailHandler(const RunFunction& run_func, bool, TailNumber tail_number)
36 {
37 if(tail_number == TailNumber::Odd)
38 {
39 return run_func(bool_constant<true>{},
41 }
42 else // Even tail number
43 {
44 return run_func(bool_constant<true>{},
46 }
48 }
49};
50
51template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
54{
56
60
64
68
71
74
77
78 static constexpr auto config =
79 BlockWeightPreshuffle::BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
80
81 using WG = remove_cvref_t<decltype(config.template at<0>())>;
82
83 static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
84 static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
85
86 static constexpr index_t BlockSize = Problem::kBlockSize;
87 static constexpr index_t WaveSize = get_warp_size();
88
89 static constexpr index_t kMPerBlock = BlockGemmShape::kM;
90 static constexpr index_t kNPerBlock = BlockGemmShape::kN;
91 static constexpr index_t kKPerBlock = BlockGemmShape::kK;
92
93 // bogus variables to compile grouped gemm (to be removed)
94 static constexpr index_t MPerBlock = BlockGemmShape::kM;
95 static constexpr index_t NPerBlock = BlockGemmShape::kN;
96 static constexpr index_t KPerBlock = BlockGemmShape::kK;
97
98 static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
99 static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
100
101 template <bool IsWave32Host = false>
102 static constexpr index_t GetVectorSizeA()
103 {
104 return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
105 }
106 template <bool IsWave32Host = false>
107 static constexpr index_t GetVectorSizeB()
108 {
109 return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
110 }
111
112 static constexpr index_t GetVectorSizeC()
113 {
114 return PipelinePolicy::template GetVectorSizeC<Problem>();
115 }
116
117 static constexpr bool kPadM = Problem::kPadM;
118 static constexpr bool kPadN = Problem::kPadN;
119 static constexpr bool kPadK = Problem::kPadK;
120
121 static constexpr index_t kLdsAlignmentInBytes = 16;
122 static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
123
124 static constexpr auto I0 = number<0>();
125 static constexpr auto I1 = number<1>();
126 static constexpr auto I2 = number<2>();
127 static constexpr auto idxM = I0;
128 static constexpr auto idxN = I1;
129 static constexpr auto idxK = I2;
133
134 static constexpr index_t MWarp = config.template at<1>();
135 static constexpr index_t NWarp = config.template at<2>();
136
137 static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WG::kM);
138 static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN);
139 static constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
140
143
146
147 static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
151 static constexpr auto TailNum = Problem::TailNum;
152
153#ifdef __gfx942__
154 static constexpr index_t mfma_per_wg = 2;
155#else
156 static constexpr index_t mfma_per_wg = 1;
157#endif
158 static constexpr index_t dsread_per_wg =
159 max(index_t(WG::kM * WG::kK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
160#if defined(__HIP_DEVICE_COMPILE__)
161 static_assert((WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
162 Problem::VectorLoadSize ==
163 0);
164#endif
165 static constexpr index_t dsread_num_perK =
166 WG::kM * WG::kK * sizeof(ADataType) * MIterPerWarp / WaveSize / Problem::VectorLoadSize;
170 static constexpr index_t Aload_rep = dswrite_rep;
171 static constexpr index_t Bload_num_perK = kNPerBlock * WG::kK / NWarp / K1 / WaveSize;
172 static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
173 static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
174
178
179 [[nodiscard]] CK_TILE_HOST static const std::string GetName()
180 {
181 // clang-format off
182 return concat('_', "pipeline_AGmemBGmemCRegV2",
184 concat('x', WG::kM, WG::kN, WG::kK),
186 concat('x', kPadM, kPadN, kPadK));
187
188 // clang-format on
189 }
190
191 static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
192 static constexpr index_t Preshuffle = Problem::Preshuffle;
194
195 CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
196
198 {
199 return PipelinePolicy::template GetSmemSize<Problem>();
200 }
201
202 // dsread_perM: how many LDS reads want to issue in this M-iter
203 // dswrite_perM: how many LDS writes you want to do this M-iter
204 // load_perM: how many global loads VMEM want to do in this M-iter
205 CK_TILE_HOST_DEVICE static constexpr auto
206 SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
207 {
208
209 // Init inst order
210 index_t max_data_inst = dsread_perM > load_perM
211 ? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
212 : (load_perM > dswrite_perM ? load_perM : dswrite_perM);
213 index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
214 index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK);
215
216 constexpr int kOrderCap = NIterPerWarp * 10;
217 index_t inst_order[kOrderCap] = {};
218 index_t index = 0;
219#pragma unroll
220 // round-robin
221 // Index: 0 1 2 3 4 5 ...
222 // Value: 1 2 3 1 2 3 ...
223 for(int j = 0; j < max_data_inst; j++)
224 {
225 if(dswrite_perM > j)
226 {
227 inst_order[index] = 1;
228 index++;
229 }
230 if(load_perM > j)
231 {
232 inst_order[index] = 2;
233 index++;
234 }
235 if(dsread_perM > j)
236 {
237 inst_order[index] = 3;
238 index++;
239 }
240 }
241
242// Schedule IGLP
243#pragma unroll
244 for(int j = 0; j < mfma_perM_perK; j++)
245 {
246 index_t inst_idx = 0;
247 if(j == 0)
248 ;
249 else if(j == 1)
250 inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
251 else if(j == 2)
252 inst_idx = mfma_perM_perK - 1;
253 else
254 inst_idx = mfma_perM_perK - j;
255
256 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
257
258#pragma unroll
259 for(int r = 0; r < round_data_inst; r++)
260 {
261 if(r % 2 == 0)
262 {
263 if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
264 {
265 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
266 }
267 if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
268 {
269 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
270 }
271 if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
272 {
273 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
274 }
275 }
276 else
277 {
278 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
279 {
280 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
281 }
282 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
283 {
284 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
285 }
286 if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
287 {
288 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
289 }
290 }
291 }
292 }
293 }
294
295 CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
296 {
297 // Keypoint of pipeline optimize is workload balance in time
298 // instruction schedule example(128X256X256, 1X4, 16X16X128):
299 // Iter MNK MFMA ds_read ds_write A_load b_load
300 // -1 M6N0: 57 - 8 - -
301 // -1 M6N1: 58 1 - - -
302 // -1 M6N2: 59 - - 7 -
303 // -1 M6N3: 60 2 - - -
304 // -1 M7N0: 61 - - - -
305 // -1 M7N1: 62 3 - - -
306 // -1 M7N2: 63 - - 8 -
307 // -1 M7N3: 64 4 - - -
308 // 0 M0N0K0: 1 - - - 1
309 // 0 M0N1: 2 5 - - -
310 // 0 M0N2: 3 - - - 2
311 // 0 M0N3: 4 6 - - -
312 // 0 M1N0: 5 - - - 3
313 // 0 M1N1: 6 7 - - -
314 // 0 M1N2: 7 - - - 4
315 // 0 M1N3: 8 8 - - -
316 // 0 M2N0: 9 - - - 5
317 // 0 M2N1: 10 9 - - -
318 // 0 M2N2: 11 - - - 6
319 // 0 M2N3: 12 10 - - -
320 // 0 M3N0: 13 - 1 - 7
321 // 0 M3N1: 14 11 - - -
322 // 0 M3N2: 15 - - - 8
323 // 0 M3N3: 16 12 - - -
324 // 0 M4N0: 17 - 2 - -
325 // 0 M4N1: 18 13 - - -
326 // 0 M4N2: 19 - - 1 -
327 // 0 M4N3: 20 14 - - -
328 // 0 M5N0: 21 - 3 - -
329 // 0 M5N1: 22 15 - - -
330 // 0 M5N2: 23 - - 2 -
331 // 0 M5N3: 24 16 - - -
332 // 0 M6N0: 25 - 4 - -
333 // 0 M6N1: 26 17 - - -
334 // 0 M6N2: 27 - - 3 -
335 // 0 M6N3: 28 18 - - -
336 // 0 M7N0: 29 - - - -
337 // 0 M7N1: 30 19 - - -
338 // 0 M7N2: 31 - - 4 -
339 // 0 M7N3: 32 20 - - -
340 // 0 M0N0K1: 33 - - - 9
341 // 0 M0N1: 34 21 - - -
342 // 0 M0N2: 35 - - - 10
343 // 0 M0N3: 36 22 - - -
344 // 0 M1N0: 37 - - - 11
345 // 0 M1N1: 38 23 - - -
346 // 0 M1N2: 39 - - - 12
347 // 0 M1N3: 40 24 - - -
348 // 0 M2N0: 41 - - - 13
349 // 0 M2N1: 42 25 - - -
350 // 0 M2N2: 43 - - - 14
351 // 0 M2N3: 44 26 - - -
352 // 0 M3N0: 45 - 5 - 15
353 // 0 M3N1: 46 27 - - -
354 // 0 M3N2: 47 - - - 16
355 // 0 M3N3: 48 28 - - -
356 // 0 M4N0: 49 - 6 - -
357 // 0 M4N1: 50 29 - - -
358 // 0 M4N2: 51 - - 5 -
359 // 0 M4N3: 52 30 - - -
360 // 0 M5N0: 53 - 7 - -
361 // 0 M5N1: 54 31 - - -
362 // 0 M5N2: 55 - - 6 -
363 // 0 M5N3: 56 32 - - -
364 // 0 M6N0: 57 - 8 - -
365 // 0 M6N1: 58 1 - - -
366 // 0 M6N2: 59 - - 7 -
367 // 0 M6N3: 60 2 - - -
368 // 0 M7N0: 61 - - - -
369 // 0 M7N1: 62 3 - - -
370 // 0 M7N2: 63 - - 8 -
371 // 0 M7N3: 64 4 - - -
372
373#pragma unroll
374 for(int kIter = 0; kIter < KIterPerWarp; kIter++)
375 {
376#pragma unroll
377 for(int mIter = 0; mIter < MIterPerWarp; mIter++)
378 {
379 index_t dsread_perM = 0;
380 index_t dswrite_perM = 0;
381 index_t load_perM = 0;
382
383 // Calculate ds_read number per M
384 dsread_perM = dsread_per_wg;
385
386 // Calculate ds_write number per M
387 if(mIter == 0)
388 {
389 dswrite_perM =
392 : 0;
393 }
394 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
395 {
396 dswrite_perM = 0;
397 }
398 else
399 {
400 dswrite_perM = (dswrite_num_perK -
401 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
403 : 0;
404 }
405 // Add ds write when ds write data > needed
406 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
407 {
408 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
409 dswrite_perM = 1;
410 }
411
412 // Calculate buffer_load number per M
413 if(mIter < HalfMIter)
414 {
415 load_perM =
416 ((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
417 : 0) +
418 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
419 : 0);
420 }
421 else
422 {
423 load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
424 ? Aload_rep
425 : 0;
426 }
427 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
428 }
429 }
430 // Add Aload when Aload data > needed
431 if(Aload_num_perK == 0)
432 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
433 __builtin_amdgcn_sched_barrier(0);
434 }
435
437 {
438#pragma unroll
439 for(int kIter = 0; kIter < KIterPerWarp; kIter++)
440 {
441#pragma unroll
442 for(int mIter = 0; mIter < MIterPerWarp; mIter++)
443 {
444 index_t dsread_perM = 0;
445 index_t dswrite_perM = 0;
446 index_t load_perM = 0;
447
448 // Calculate ds_read number per M
449 dsread_perM = dsread_per_wg;
450
451 // Calculate ds_write number per M
452 if(mIter == 0)
453 {
454 dswrite_perM =
457 : 0;
458 }
459 else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
460 {
461 dswrite_perM = 0;
462 }
463 else
464 {
465 dswrite_perM = (dswrite_num_perK -
466 (MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
468 : 0;
469 }
470 // Add ds write when ds write data > needed
471 if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
472 {
473 if(mIter == MIterPerWarp - 1 - dswrite_mIter)
474 dswrite_perM = 1;
475 }
476
477 // Calculate buffer_load number per M
478 if(mIter < HalfMIter)
479 {
480 load_perM =
481 ((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
482 : 0);
483 }
484 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
485 }
486 }
487 __builtin_amdgcn_sched_barrier(0);
488 }
489
491 {
492#pragma unroll
493 for(int kIter = 0; kIter < KIterPerWarp; kIter++)
494 {
495#pragma unroll
496 for(int mIter = 0; mIter < MIterPerWarp; mIter++)
497 {
498 index_t dsread_perM = 0;
499 index_t dswrite_perM = 0;
500 index_t load_perM = 0;
501
502 // Calculate ds_read number per M
503 if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
504 dsread_perM = dsread_per_wg;
505
506 SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
507 }
508 }
509 // __builtin_amdgcn_sched_barrier(0);
510 }
511
512 template <TailNumber TailNum,
513 typename ADramBlockWindowTmp,
514 typename BFlatBlockWindowTmp,
515 typename AElementFunction,
516 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
518 bool>* = nullptr,
519 index_t UnaryOpSize_ = 8>
520 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
521 const AElementFunction& a_element_func,
522 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
523 index_t num_loop,
524 void* p_smem_ping,
525 void* p_smem_pong) const
526 {
527 static_assert(
528 std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
529 "wrong!");
530
531 static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
532 "wrong!");
533 static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
534 "wrong!");
535
536 constexpr auto MIter_2nd_last = (MIterPerWarp >= 2) ? MIterPerWarp - 2 : MIterPerWarp - 1;
537 const index_t iMWarp = get_warp_id() / NWarp;
538
539 using CWarpDstr = typename WG::CWarpDstr;
540 using CWarpTensor = typename WG::CWarpTensor;
541
542 constexpr auto c_warp_y_lengths =
543 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
544 constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
545
546 __builtin_amdgcn_sched_barrier(0);
547
548 // A tile in LDS
549 ADataType* p_a_lds_ping = static_cast<ADataType*>(p_smem_ping);
550 ADataType* p_a_lds_pong = static_cast<ADataType*>(p_smem_pong);
551
552 constexpr auto a_lds_block_desc =
553 PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
554
555 auto a_lds_block_ping =
556 make_tensor_view<address_space_enum::lds>(p_a_lds_ping, a_lds_block_desc);
557 auto a_lds_block_pong =
558 make_tensor_view<address_space_enum::lds>(p_a_lds_pong, a_lds_block_desc);
559
560 // A DRAM tile window for load
561 auto a_copy_dram_window =
562 make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(),
564 a_dram_block_window_tmp.get_window_origin(),
565 PipelinePolicy::template MakeADramTileDistribution<Problem>());
566
567 auto a_copy_lds_window_ping =
568 make_tile_window(a_lds_block_ping,
570 {0, 0},
571 PipelinePolicy::template MakeADramTileDistribution<Problem>());
572
573 auto a_copy_lds_window_pong =
574 make_tile_window(a_lds_block_pong,
576 {0, 0},
577 PipelinePolicy::template MakeADramTileDistribution<Problem>());
578
579 // ping-pong window for A LDS
580 auto a_warp_window_ping_tmp =
581 make_tile_window(a_lds_block_ping,
583 {iMWarp * WG::kM, 0},
584 make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
585
586 auto a_warp_window_pong_tmp =
587 make_tile_window(a_lds_block_pong,
589 {iMWarp * WG::kM, 0},
590 make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
591
593 statically_indexed_array<decltype(a_warp_window_ping_tmp), KIterPerWarp>,
595 a_warp_windows_ping;
596
598 statically_indexed_array<decltype(a_warp_window_pong_tmp), KIterPerWarp>,
600 a_warp_windows_pong;
601
602 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
603 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
604 a_warp_windows_ping(mIter)(kIter) = a_warp_window_ping_tmp;
605
606 move_tile_window(a_warp_windows_ping(mIter)(kIter),
607 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
608 });
609 });
610
611 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
612 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
613 a_warp_windows_pong(mIter)(kIter) = a_warp_window_pong_tmp;
614
615 move_tile_window(a_warp_windows_pong(mIter)(kIter),
616 {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter});
617 });
618 });
619
620 // Block GEMM
621 auto block_weight_preshuffle = BlockWeightPreshuffle();
622 // Acc register tile
623 auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
624
625 // B flat DRAM window for load
626 auto b_flat_distribution =
627 PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
628 auto b_flat_dram_window = // tile_window_with_static_distribution
630 b_flat_dram_block_window_tmp.get_bottom_tensor_view(), // from kernel gemm_pad_views
632 b_flat_dram_block_window_tmp.get_window_origin(),
633 b_flat_distribution);
634
635 // pingpong buffer for B
636 using BTypeToUse =
637 std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
638 using BTileType = decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
639
641 statically_indexed_array<decltype(b_flat_dram_window), KIterPerWarp>,
643 b_flat_dram_windows;
644
646 b_warp_tensor_ping;
647
649 b_warp_tensor_pong;
650
651 // Prefetch A0
652 auto a_block_tile = load_tile(a_copy_dram_window);
653 // move A window to next k
654 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
655
656 // prefetch B
657 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
658 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
659 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
660
661 move_tile_window(b_flat_dram_windows(nIter)(kIter),
662 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
663
665 b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
666 });
667 });
668 // move B window to next flat K
669 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
670
671 // Prefill A0
672 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
673 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
674
675 __builtin_amdgcn_sched_barrier(0);
676
677 // Prefetch A1
678 a_block_tile = load_tile(a_copy_dram_window);
679 // move A window to next k
680 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
681
682 // initialize C
683 tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
684
686
687 // preload A00,A10 from lds
688 statically_indexed_array<decltype(load_tile(a_warp_windows_ping(number<0>{})(number<0>{}))),
689 m_preload>
690 a_warp_tensor;
691
692 static_for<0, m_preload, 1>{}([&](auto loadIter) {
693 constexpr auto mIter = loadIter % MIterPerWarp;
694 constexpr auto kIter = loadIter / MIterPerWarp;
695 a_warp_tensor(loadIter) =
696 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
697 });
698 __builtin_amdgcn_sched_barrier(0);
699
700 // MAIN LOOP
701 index_t iCounter = (num_loop - 1) / 2;
702 while(iCounter > 0)
703 {
704 // prefetch B(2i+1)
705 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
706 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
707 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
708
709 move_tile_window(b_flat_dram_windows(nIter)(kIter),
710 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
711
713 b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
714 });
715 });
716
717 // Prefill A(2i+1)
718 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
719 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
720
721 // Prefetch A(2i+2)
722 a_block_tile = load_tile(a_copy_dram_window);
723 // move A window to next k
724 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
725
726 // GEMM 2i
727 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
728 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
729 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
730 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
731 // read C warp tensor from C block tensor
732 CWarpTensor c_warp_tensor;
733
734 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
735 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
736 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
737
738 // warp GEMM
739 WG{}(c_warp_tensor,
740 a_warp_tensor(number<AwarpIter>{}),
741 b_warp_tensor_ping(nIter)(kIter));
742
743 // write C warp tensor into C block tensor
744 c_block_tile.set_y_sliced_thread_data(
745 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
746 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
747 c_warp_tensor.get_thread_buffer());
748
749 __builtin_amdgcn_sched_barrier(0x7F6);
750 });
751 // preload next A from lds
752 if constexpr((kIter * MIterPerWarp + mIter) <
754 {
755 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
756 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
757 a_warp_tensor(number<AwarpIter>{}) =
758 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
759 }
760
761 // barrier
762 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
763 {
765 }
766 });
767 });
768 // move B window to next flat K
769 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
770
771 static_for<0, m_preload, 1>{}([&](auto loadIter) {
772 constexpr auto mIter = loadIter % MIterPerWarp;
773 constexpr auto kIter = loadIter / MIterPerWarp;
774 a_warp_tensor(loadIter) =
775 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
776 });
778
779 // Next K
780
781 // prefetch B(2i+2)
782 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
783 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
784 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
785
786 move_tile_window(b_flat_dram_windows(nIter)(kIter),
787 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
788
790 b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
791 });
792 });
793
794 // Prefill A(2i+2)
795 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
796 store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
797
798 // Prefetch A(2i+3)
799 a_block_tile = load_tile(a_copy_dram_window);
800 // move A window to next k
801 move_tile_window(a_copy_dram_window, {0, kKPerBlock});
802
803 // GEMM 2i+1
804 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
805 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
806 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
807 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
808 // read C warp tensor from C block tensor
809 CWarpTensor c_warp_tensor;
810 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
811 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
812 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
813
814 // warp GEMM
815 WG{}(c_warp_tensor,
816 a_warp_tensor(number<AwarpIter>{}),
817 b_warp_tensor_pong(nIter)(kIter));
818
819 // write C warp tensor into C block tensor
820 c_block_tile.set_y_sliced_thread_data(
821 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
822 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
823 c_warp_tensor.get_thread_buffer());
824
825 __builtin_amdgcn_sched_barrier(0x7F6);
826 });
827 // preload next A from lds
828 if constexpr((kIter * MIterPerWarp + mIter) <
830 {
831 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
832 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
833 a_warp_tensor(number<AwarpIter>{}) =
834 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
835 }
836
837 // barrier
838 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
839 {
841 }
842 });
843 });
844 // move B window to next flat K
845 move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
846
847 static_for<0, m_preload, 1>{}([&](auto loadIter) {
848 constexpr auto mIter = loadIter % MIterPerWarp;
849 constexpr auto kIter = loadIter / MIterPerWarp;
850 a_warp_tensor(loadIter) =
851 load_tile(a_warp_windows_ping(number<mIter>{})(number<kIter>{}));
852 });
854
855 iCounter--;
856 }
857
858 // tail
859 if constexpr(TailNum == TailNumber::Even)
860 {
861 // __builtin_amdgcn_sched_barrier(0);
862 // prefetch B(loopK)
863 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
864 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
865 b_flat_dram_windows(nIter)(kIter) = b_flat_dram_window;
866
867 move_tile_window(b_flat_dram_windows(nIter)(kIter),
868 {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter});
869
871 b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
872 });
873 });
874
875 // Prefill A(loopK)
876 a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
877 store_tile(a_copy_lds_window_pong, a_block_tile_tmp);
878
879 // GEMM loopK-1
880 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
881 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
882 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
883 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
884 // read C warp tensor from C block tensor
885 CWarpTensor c_warp_tensor;
886
887 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
888 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
889 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
890
891 // warp GEMM
892 WG{}(c_warp_tensor,
893 a_warp_tensor(number<AwarpIter>{}),
894 b_warp_tensor_ping(nIter)(kIter));
895
896 // write C warp tensor into C block tensor
897 c_block_tile.set_y_sliced_thread_data(
898 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
899 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
900 c_warp_tensor.get_thread_buffer());
901
902 __builtin_amdgcn_sched_barrier(0x7F6);
903 });
904 // preload next A from lds
905 if constexpr((kIter * MIterPerWarp + mIter) <
907 {
908 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
909 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
910 a_warp_tensor(number<AwarpIter>{}) =
911 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
912 }
913
914 // barrier
915 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
916 {
918 }
919 });
920 });
921 // TailHotLoopScheduler();
922
923 static_for<0, m_preload, 1>{}([&](auto loadIter) {
924 constexpr auto mIter = loadIter % MIterPerWarp;
925 constexpr auto kIter = loadIter / MIterPerWarp;
926 a_warp_tensor(loadIter) =
927 load_tile(a_warp_windows_pong(number<mIter>{})(number<kIter>{}));
928 });
929
931
932 // GEMM loopK
933 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
934 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
935 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
936 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
937 // read C warp tensor from C block tensor
938 CWarpTensor c_warp_tensor;
939
940 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
941 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
942 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
943
944 // warp GEMM
945 WG{}(c_warp_tensor,
946 a_warp_tensor(number<AwarpIter>{}),
947 b_warp_tensor_pong(nIter)(kIter));
948
949 // write C warp tensor into C block tensor
950 c_block_tile.set_y_sliced_thread_data(
951 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
952 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
953 c_warp_tensor.get_thread_buffer());
954 });
955 if constexpr((kIter * MIterPerWarp + mIter) <
957 {
958 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
959 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
960 a_warp_tensor(number<AwarpIter>{}) =
961 load_tile(a_warp_windows_pong(number<AmIter>{})(number<AkIter>{}));
962 }
963 // barrier
964 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
965 {
967 }
968 });
969 });
971 }
972 else if constexpr(TailNum == TailNumber::Odd)
973 {
974 // GEMM loopK
975 static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
976 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
977 constexpr auto AwarpIter = (kIter * MIterPerWarp + mIter) % m_preload;
978 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
979 // read C warp tensor from C block tensor
980 CWarpTensor c_warp_tensor;
981
982 c_warp_tensor.get_thread_buffer() = c_block_tile.get_y_sliced_thread_data(
983 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
984 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
985
986 // warp GEMM
987 WG{}(c_warp_tensor,
988 a_warp_tensor(number<AwarpIter>{}),
989 b_warp_tensor_ping(nIter)(kIter));
990
991 // write C warp tensor into C block tensor
992 c_block_tile.set_y_sliced_thread_data(
993 merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
994 merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
995 c_warp_tensor.get_thread_buffer());
996
997 __builtin_amdgcn_sched_barrier(0x7F6);
998 });
999 // preload next A from lds
1000 if constexpr((kIter * MIterPerWarp + mIter) <
1002 {
1003 constexpr auto AmIter = (mIter + m_preload) % MIterPerWarp;
1004 constexpr auto AkIter = (kIter + (mIter + m_preload) / MIterPerWarp);
1005 a_warp_tensor(number<AwarpIter>{}) =
1006 load_tile(a_warp_windows_ping(number<AmIter>{})(number<AkIter>{}));
1007 }
1008
1009 // barrier
1010 if constexpr((kIter == KIterPerWarp - 1) && (mIter == MIter_2nd_last))
1011 {
1013 }
1014 });
1015 });
1017 }
1018
1019 return c_block_tile;
1020 }
1021
1022 // called from universal gemm kernel
1023 template <typename ADramBlockWindowTmp,
1024 typename BFlatBlockWindowTmp,
1025 typename AElementFunction,
1026 typename BElementFunction,
1027 typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
1029 bool>* = nullptr>
1030 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
1031 [[maybe_unused]] const AElementFunction& a_element_func,
1032 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
1033 [[maybe_unused]] const BElementFunction& b_element_func,
1034 index_t num_loop,
1035 void* p_smem_ping,
1036 void* p_smem_pong) const
1037 {
1038 return operator()<TailNum>(
1039 a_dram_block_window_tmp[number<0>{}],
1040 [](const ADataType& a) { return a; },
1041 b_flat_dram_block_window_tmp[number<0>{}],
1042 num_loop,
1043 p_smem_ping,
1044 p_smem_pong);
1045 }
1046
1047 // called from general gemm kernel
1048 template <typename ADramBlockWindowTmp,
1049 typename BFlatBlockWindowTmp,
1050 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
1052 bool>* = nullptr>
1053 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
1054 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
1055 index_t num_loop,
1056 void* p_smem_ping,
1057 void* p_smem_pong) const
1058 {
1059 return operator()<TailNum>(
1060 a_dram_block_window_tmp,
1061 [](const ADataType& a) { return a; },
1062 b_flat_dram_block_window_tmp,
1063 num_loop,
1064 p_smem_ping,
1065 p_smem_pong);
1066 }
1067
1068 // called from grouped gemm kernel
1069 template <typename ADramBlockWindowTmp,
1070 typename BFlatBlockWindowTmp,
1071 typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
1073 bool>* = nullptr>
1074 CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
1075 const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
1076 index_t num_loop,
1077 TailNumber tail_number,
1078 void* __restrict__ p_smem_0,
1079 void* __restrict__ p_smem_1) const
1080 {
1081 const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
1082 (void)bool_val; // Suppress unused parameter warning
1083 constexpr auto tail_num = tail_num_.value;
1084 constexpr auto PassThrough = [](const auto& x) { return x; };
1085 return operator()<tail_num>(a_dram_block_window_tmp,
1087 b_flat_dram_block_window_tmp,
1088 num_loop,
1089 p_smem_0,
1090 p_smem_1);
1091 };
1092 return Base::TailHandler(RunPipeline, true, tail_number);
1093 }
1094};
1095
1096} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST
Definition config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE void load_int4_tile(WarpTile &dst, const WarpWindow &src)
Definition load_interleaved_pk_type.hpp:46
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition tile_elementwise.hpp:40
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
TailNumber
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:21
@ Even
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:24
@ Odd
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:23
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition type_traits.hpp:67
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
ck_tile::element_wise::PassThrough PassThrough
Definition grouped_convolution_utils.hpp:47
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition concat.hpp:43
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
tuple_array< T, N > statically_indexed_array
Definition tile/core/container/statically_indexed_array.hpp:16
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:15
static constexpr bool UsePersistentKernel
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:19
static constexpr index_t PrefillStages
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:17
static constexpr index_t PrefetchStages
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:16
static CK_TILE_HOST_DEVICE auto TailHandler(const RunFunction &run_func, bool, TailNumber tail_number)
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:35
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:21
static CK_TILE_HOST_DEVICE constexpr bool BlockHasHotloop(index_t num_loop)
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:23
static constexpr index_t GlobalBufferNum
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:18
static CK_TILE_HOST_DEVICE constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:28
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:54
remove_cvref_t< typename BlockGemmShape::BlockTile > BlockTile
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:130
static constexpr index_t mfma_per_wg
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:156
static constexpr bool DoubleSmemBuffer
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:191
static constexpr index_t GetVectorSizeB()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:107
static constexpr index_t Aload_rep
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:170
static constexpr index_t DsReadPreload
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:84
static constexpr index_t GetVectorSizeA()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:102
static constexpr index_t dswrite_kIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:177
remove_cvref_t< typename Problem::AsDataTypeTuple > AsDataType
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:57
static CK_TILE_HOST const std::string GetName()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:179
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, const BElementFunction &b_element_func, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:1030
remove_cvref_t< typename Problem::BsDataTypeTuple > BsDataType
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:58
remove_cvref_t< std::tuple_element_t< 0, AsLayout > > ALayout
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:69
static CK_TILE_HOST_DEVICE constexpr auto HotLoopScheduler()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:295
static constexpr index_t DsWritePreIssue
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:83
static constexpr index_t NWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:135
static constexpr bool kPadN
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:118
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, index_t num_loop, TailNumber tail_number, void *__restrict__ p_smem_0, void *__restrict__ p_smem_1) const
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:1074
static constexpr index_t MWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:134
static constexpr bool kPadK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:119
static constexpr auto idxK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:129
static constexpr index_t MIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:137
static constexpr index_t m_preload
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:148
remove_cvref_t< std::tuple_element_t< 0, AsDataType > > ADataType
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:72
static constexpr index_t NIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:138
static constexpr bool kPadM
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:117
BaseWeightPreshufflePipelineAGmemBGmemCRegV2< Problem > Base
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:55
static constexpr index_t dswrite_num_perK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:167
static constexpr index_t Bload_num_perK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:171
remove_cvref_t< std::tuple_element_t< 0, BsLayout > > BLayout
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:70
remove_cvref_t< decltype(config.template at< 0 >())> WG
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:81
remove_cvref_t< typename BlockGemmShape::BlockWarps > BlockWarps
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:131
static constexpr index_t mfma_perM_perK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:175
static constexpr index_t GetVectorSizeC()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:112
remove_cvref_t< typename Problem::BElementWise > BElementWise
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:62
static constexpr index_t KPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:96
static constexpr index_t KIterPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:139
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:197
remove_cvref_t< decltype(PipelinePolicy::template GetBlockWeightPreshuffle< Problem >())> BlockWeightPreshuffle
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:75
remove_cvref_t< typename Problem::CDataType > CDataType
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:59
static CK_TILE_HOST_DEVICE constexpr auto SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:206
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const AElementFunction &a_element_func, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:520
static constexpr auto idxM
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:127
static constexpr index_t MPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:94
static constexpr index_t KFlatPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:141
static constexpr index_t flatKPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:98
static constexpr auto I0
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:124
static constexpr index_t NPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:95
static constexpr index_t NumWaveGroups
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:122
static constexpr auto config
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:78
static constexpr index_t kNPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:90
static constexpr index_t dswrite_rep
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:168
static constexpr index_t flatNPerWarp
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:99
static constexpr index_t Aload_num_perK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:169
static constexpr index_t BlockSize
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:86
static CK_TILE_HOST_DEVICE constexpr auto LastHotLoopScheduler()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:490
remove_cvref_t< typename BlockGemmShape::WarpTile > WarpTile
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:132
static constexpr auto I2
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:126
remove_cvref_t< std::tuple_element_t< 0, BsDataType > > BDataType
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:73
static constexpr index_t KPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:145
remove_cvref_t< typename Problem::AElementWise > AElementWise
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:61
static constexpr index_t NFlatPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:142
static constexpr index_t Bload_rep
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:173
static constexpr index_t kKPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:91
static constexpr index_t K1
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:147
static constexpr index_t MPerBlockPerIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:144
remove_cvref_t< typename Problem::BlockGemmShape > BlockGemmShape
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:63
remove_cvref_t< typename Problem::CLayout > CLayout
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:67
static constexpr index_t kMPerBlock
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:89
static constexpr auto TailNum
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:151
remove_cvref_t< typename Problem::AsLayoutTuple > AsLayout
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:65
static constexpr index_t Preshuffle
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:192
static constexpr index_t dsread_num_perK
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:165
static constexpr index_t kLdsAlignmentInBytes
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:121
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp &a_dram_block_window_tmp, const BFlatBlockWindowTmp &b_flat_dram_block_window_tmp, index_t num_loop, void *p_smem_ping, void *p_smem_pong) const
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:1053
static constexpr index_t WaveSize
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:87
static constexpr index_t dswrite_mIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:176
static constexpr auto idxN
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:128
static CK_TILE_HOST_DEVICE constexpr auto Last2ndHotLoopScheduler()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:436
static constexpr index_t HalfMIter
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:172
static constexpr index_t dsread_per_wg
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:158
remove_cvref_t< typename Problem::BsLayoutTuple > BsLayout
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:66
static constexpr auto I1
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:125
static CK_TILE_HOST_DEVICE constexpr auto TransposeC()
Definition wp_pipeline_agmem_bgmem_creg_v2.hpp:195
Definition tile/core/numeric/integral_constant.hpp:30
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43