device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp Source File

device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp Source File

Composable Kernel: device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp Source File
device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24template <typename ALayout,
25 typename BLayout,
26 typename CLayout,
27 typename ADataType,
28 typename BDataType,
29 typename CDataType,
30 typename GemmAccDataType,
31 typename CShuffleDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CElementwiseOperation,
35 GemmSpecialization GemmSpec,
36 index_t BlockSize,
37 index_t MPerBlock,
38 index_t NPerBlock,
39 index_t KPerBlock,
40 index_t AK1,
41 index_t BK1,
42 index_t MPerXDL,
43 index_t NPerXDL,
44 index_t MXdlPerWave,
45 index_t NXdlPerWave,
46 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
47 typename ABlockTransferThreadClusterArrangeOrder,
48 typename ABlockTransferSrcAccessOrder,
49 index_t ABlockTransferSrcVectorDim,
50 index_t ABlockTransferSrcScalarPerVector,
51 index_t ABlockTransferDstScalarPerVector_AK1,
52 bool ABlockLdsExtraM,
53 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
54 typename BBlockTransferThreadClusterArrangeOrder,
55 typename BBlockTransferSrcAccessOrder,
56 index_t BBlockTransferSrcVectorDim,
57 index_t BBlockTransferSrcScalarPerVector,
58 index_t BBlockTransferDstScalarPerVector_BK1,
59 bool BBlockLdsExtraN,
60 index_t CShuffleMXdlPerWavePerShuffle,
61 index_t CShuffleNXdlPerWavePerShuffle,
62 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
63 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
66 typename ComputeTypeA = CDataType,
67 typename ComputeTypeB = ComputeTypeA,
68 bool PermuteA = false,
69 bool PermuteB = false>
71 BLayout,
72 CLayout,
73 ADataType,
74 BDataType,
75 CDataType,
76 AElementwiseOperation,
77 BElementwiseOperation,
78 CElementwiseOperation>
79{
// N-direction XDL repeat counts obtained from GetNXdlPerWave<true>() and
// GetNXdlPerWave<false>(). The "64"/"32" suffixes match the get_warp_size() == 64 dispatch in
// IsSupportedArgument. GetNXdlPerWave itself is not defined in the visible text
// (presumably supplied by the GET_NXDL_PER_WAVE_IMPL macro -- TODO confirm).
81 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
82 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
83
84 // GridwiseGemm
// NOTE(review): the line that opens this template-argument list is missing from this extract.
// Per the reference index at the end of this page, the list belongs to the alias
// GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_v3_b_preshuffle<...>, which is instantiated as
// GridwiseGemm64 = GridwiseGemmBase<math::max(NXdlPerWave64, 1)> and
// GridwiseGemm32 = GridwiseGemmBase<NXdlPerWave32>.
// The arguments forward the device-op template parameters, except that NXdlPerWave_ takes the
// place of NXdlPerWave and a literal `false` is passed immediately before ABlockLdsExtraM and
// before BBlockLdsExtraN (meaning of those two flags is not visible here -- TODO confirm).
85 template <index_t NXdlPerWave_>
87 ALayout,
88 BLayout,
89 CLayout,
90 ADataType,
91 BDataType,
92 GemmAccDataType,
93 CShuffleDataType,
94 CDataType,
95 AElementwiseOperation,
96 BElementwiseOperation,
97 CElementwiseOperation,
98 GemmSpec,
99 BlockSize,
100 MPerBlock,
101 NPerBlock,
102 KPerBlock,
103 AK1,
104 BK1,
105 MPerXDL,
106 NPerXDL,
107 MXdlPerWave,
108 NXdlPerWave_,
109 ABlockTransferThreadClusterLengths_AK0_M_AK1,
110 ABlockTransferThreadClusterArrangeOrder,
111 ABlockTransferSrcAccessOrder,
112 ABlockTransferSrcVectorDim,
113 ABlockTransferSrcScalarPerVector,
114 ABlockTransferDstScalarPerVector_AK1,
115 false,
116 ABlockLdsExtraM,
117 BBlockTransferThreadClusterLengths_BK0_N_BK1,
118 BBlockTransferThreadClusterArrangeOrder,
119 BBlockTransferSrcAccessOrder,
120 BBlockTransferSrcVectorDim,
121 BBlockTransferSrcScalarPerVector,
122 BBlockTransferDstScalarPerVector_BK1,
123 false,
124 BBlockLdsExtraN,
125 CShuffleMXdlPerWavePerShuffle,
126 CShuffleNXdlPerWavePerShuffle,
127 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
128 CShuffleBlockTransferScalarPerVector_NPerBlock,
129 BlkGemmPipeSched,
130 BlkGemmPipelineVer,
131 ComputeTypeA,
132 ComputeTypeB,
133 PermuteA,
134 PermuteB>;
137
// Argument type exposed by this device op; it is the Argument of the GridwiseGemm64
// instantiation. IsSupportedArgument reinterprets it as GridwiseGemm32::Argument on the
// non-64 warp-size path.
138 using Argument = typename GridwiseGemm64::Argument;
139
// Compile-time constant produced by an immediately-invoked lambda: 2 when the (missing)
// condition holds, otherwise 1. RunImp divides the A buffer byte size by this value.
// NOTE(review): the condition line is absent from this extract; presumably it tests whether
// ADataType is a packed sub-byte type -- TODO confirm.
140 static constexpr index_t APackedSize = []() {
142 return 2;
143 else
144 return 1;
145 }();
146
// Compile-time constant produced by an immediately-invoked lambda: 2 when the (missing)
// condition holds, otherwise 1. RunImp divides the B buffer byte size by this value.
// NOTE(review): the condition line is absent from this extract; presumably it tests whether
// BDataType is a packed sub-byte type -- TODO confirm.
147 static constexpr index_t BPackedSize = []() {
149 return 2;
150 else
151 return 1;
152 }();
153
154 int GetPreShuffleParameters() override { return NPerXDL; }
155
156 // Invoker
157 struct Invoker : public BaseInvoker
158 {
// Launches (and times) the GEMM kernel for one GridwiseGemm instantiation.
//
// arg           - problem description; printed when stream_config.log_level_ > 0 and checked
//                 with GridwiseGemm::CheckValidity before anything is launched.
// stream_config - launch/timing options; log_level_, flush_cache, rotating_count and
//                 stream_id_ are read here.
// Returns ave_time, which is initialised to 0 and set by the launch helper.
// Throws std::runtime_error when CheckValidity(arg) fails, or (inside the main-loop branch)
// when BlkGemmPipelineVer is not v1, v2 or v3.
//
// NOTE(review): several lines of this function are missing from this extract (the kernel
// selections, the rotating-memory construction and the head of the pre-process launch call).
// The comments below describe only what is visible.
159 template <typename GridwiseGemm>
160 float RunImp(const typename GridwiseGemm::Argument& arg,
161 const StreamConfig& stream_config = StreamConfig{})
162 {
163 if(stream_config.log_level_ > 0)
164 {
165 arg.Print();
166 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
167 }
168
169 if(!GridwiseGemm::CheckValidity(arg))
170 {
171 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
172 }
173
174 index_t gdx, gdy, gdz;
175 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
176
177 float ave_time = 0;
178
// K_split = ceil(K / (KBatch * KPerBlock)) * KPerBlock, i.e. the K extent of one split-K
// batch rounded up to whole KPerBlock tiles. It drives the main-loop / tail-number queries.
179 index_t k_grain = arg.KBatch * KPerBlock;
180 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
181
182 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
183
// Launch helper shared by every kernel variant selected below.
184 const auto Run = [&](const auto& kernel) {
185 if(stream_config.flush_cache)
186 {
// Work on a copy of the argument so the pre-process path can take it by non-const
// reference (launch_and_time_kernel_with_preprocess takes GemmArgs&).
187 auto arg_ = arg;
188
189 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
190 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
191 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
192 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
193
// Buffer sizes in bytes; element count is scaled down by the packed-size factor.
194 auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
195 sizeof(ADataType) / APackedSize;
196 auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
197 sizeof(BDataType) / BPackedSize;
198
// NOTE(review): the declaration of rotating_mem (head of this statement) is missing here.
200 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
201 rotating_mem.Print();
202
203 auto run_flush_cache = [&]() {
204 // flush icache
206 // rotating mem
207 rotating_mem.Next();
208 // clear c mem
// NOTE(review): the hipError_t returned by hipMemsetAsync is only fed to
// hipGetErrorString and the resulting string is discarded, so a failed memset
// is not reported.
209 if(arg_.KBatch > 1)
210 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
211 0,
212 arg_.M * arg_.N * sizeof(CDataType),
213 stream_config.stream_id_));
214 };
215
// NOTE(review): the head of this call (presumably the assignment of ave_time from
// launch_and_time_kernel_with_preprocess -- TODO confirm) is missing from this extract.
217 stream_config,
218 run_flush_cache,
219 kernel,
220 dim3(gdx, gdy, gdz),
221 dim3(BlockSize),
222 0,
223 arg_);
224 }
225 else
226 {
// Split-K: C is zeroed once before the launch. Same discarded-error caveat as above.
227 if(arg.KBatch > 1)
228 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
229 0,
230 arg.M * arg.N * sizeof(CDataType),
231 stream_config.stream_id_));
232
233 ave_time = launch_and_time_kernel(
234 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
235 }
236 };
237
// Rough per-thread register estimate (tile bytes / BlockSize / 4, with extra multipliers for
// A and B) used only to pick minimum_occupancy: 1 when the total reaches 256, otherwise 2.
// NOTE(review): the "/ 4" presumably converts bytes to 32-bit registers -- TODO confirm.
238 constexpr auto estimated_reg_a = MPerBlock * KPerBlock * sizeof(ADataType) / BlockSize /
239 4 * (1 + GridwiseGemm::NWave);
240 constexpr auto estimated_reg_b =
241 NPerBlock * KPerBlock * sizeof(BDataType) / BlockSize / 4 * (2);
242 constexpr auto estimated_reg_c =
243 MPerBlock * NPerBlock * sizeof(GemmAccDataType) / BlockSize / 4;
244 constexpr auto estimated_reg_total =
245 estimated_reg_a + estimated_reg_b + estimated_reg_c;
246
247 constexpr index_t minimum_occupancy = (estimated_reg_total >= 256) ? 1 : 2;
248
// Kernel selection: pipeline version (compile time) x KBatch > 1 x odd/even tail number.
// NOTE(review): the `const auto kernel = ...` lines naming the kernel template and the
// remaining template arguments are missing in every branch below.
249 if(has_main_k_block_loop)
250 {
251 // Tail number always full
252 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
253 {
254 if(arg.KBatch > 1)
255 {
256 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
257 {
259 GridwiseGemm,
260 true,
262 minimum_occupancy,
264 Run(kernel);
265 }
266 else
267 {
269 GridwiseGemm,
270 true,
272 minimum_occupancy,
274 Run(kernel);
275 }
276 }
277 else
278 {
279 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
280 {
282 GridwiseGemm,
283 true,
285 minimum_occupancy,
287 Run(kernel);
288 }
289 else
290 {
292 GridwiseGemm,
293 true,
295 minimum_occupancy,
297 Run(kernel);
298 }
299 }
300 }
301 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2 ||
302 BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
303 {
304 if(arg.KBatch > 1)
305 {
306 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
307 {
309 GridwiseGemm,
310 true,
312 minimum_occupancy,
314 Run(kernel);
315 }
316 else
317 {
319 GridwiseGemm,
320 true,
322 minimum_occupancy,
324 Run(kernel);
325 }
326 }
327 else
328 {
329 if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
330 {
332 GridwiseGemm,
333 true,
335 minimum_occupancy,
337 Run(kernel);
338 }
339 else
340 {
342 GridwiseGemm,
343 true,
345 minimum_occupancy,
347 Run(kernel);
348 }
349 }
350 }
351 else
352 {
353 throw std::runtime_error("Only support pipeline ver v1, v2, v3 now!");
354 }
355 }
// NOTE(review): the branch below is compiled out, so when has_main_k_block_loop is false no
// kernel is launched in the visible code and the function returns the initial ave_time of 0.
356#if 0
357 else
358 {
359 // Tail number always 1
360 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
361 {
362 if(arg.KBatch > 1)
363 {
364 const auto kernel =
366 false,
368 minimum_occupancy,
370 Run(kernel);
371 }
372 else
373 {
374 const auto kernel =
376 false,
378 minimum_occupancy,
380 Run(kernel);
381 }
382 }
383 }
384#endif
385
386 return ave_time;
387 }
388
390
391 // polymorphic
392 float Run(const BaseArgument* p_arg,
393 const StreamConfig& stream_config = StreamConfig{}) override
394 {
395 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
396 }
397 };
398
// Compile-time gate for this template instantiation.
// TODO: properly implement this check; for now every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool accept_all = true;
    return accept_all;
}
404
// Host-side check of whether this instance can run the problem described by `arg`.
// Visible rejections (each returns false):
//   - a first check whose condition line is missing from this extract;
//   - is_gfx11_supported() together with KBatch > 1;
//   - CDataType == ck::bhalf_t with KBatch > 1 when is_bf16_atomic_supported() is false;
//   - K not a multiple of AK1 or BK1, unless GemmSpec is MKPadding, NKPadding, MNKPadding
//     or KPadding;
//   - N not a multiple of NPerBlock, or K not a multiple of KPerBlock.
// After that the result is delegated per warp size (64 vs. otherwise); the delegating return
// statements are partly missing here. If neither delegate is compiled in, false is returned.
405 static bool IsSupportedArgument(const Argument& arg)
406 {
// NOTE(review): the `if(...)` that guards this block is missing from this extract.
408 {
409 return false;
410 }
411 if(is_gfx11_supported() && arg.KBatch > 1)
412 {
413 return false;
414 }
415 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
416 {
417 return false;
418 }
419
420 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
421 GemmSpec == GemmSpecialization::NKPadding ||
422 GemmSpec == GemmSpecialization::MNKPadding ||
423 GemmSpec == GemmSpecialization::KPadding))
424 {
425 return false;
426 }
427
// NOTE(review): applied regardless of GemmSpec -- presumably a requirement of the
// pre-shuffled B layout; confirm.
428 if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
429 {
430 return false;
431 }
432
433 if(get_warp_size() == 64)
434 {
435 if constexpr(NXdlPerWave64 > 0)
436 {
438 }
439 }
440 else
441 {
442 if constexpr(NXdlPerWave32 > 0)
443 {
// NOTE(review): Argument is GridwiseGemm64::Argument; this reinterpret_cast assumes
// GridwiseGemm32::Argument has an identical layout -- confirm.
445 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
446 }
447 }
448 return false;
449 }
450
451 // polymorphic
452 bool IsSupportedArgument(const BaseArgument* p_arg) override
453 {
454 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
455 }
456
457 index_t GetKPerBlock() override { return KPerBlock; }
458
459 bool GetPermuteA() override { return PermuteA; }
460 bool GetPermuteB() override { return PermuteB; }
461
462 static auto MakeArgument(const ADataType* p_a,
463 const BDataType* p_b,
464 CDataType* p_c,
465 index_t M,
466 index_t N,
467 index_t K,
468 index_t StrideA,
469 index_t StrideB,
470 index_t StrideC,
471 index_t KBatch,
472 AElementwiseOperation,
473 BElementwiseOperation,
474 CElementwiseOperation)
475 {
476 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
477 }
478
479 static auto MakeInvoker() { return Invoker{}; }
480
481 // polymorphic
482 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
483 const void* p_b,
484 void* p_c,
485 index_t M,
486 index_t N,
487 index_t K,
488 index_t StrideA,
489 index_t StrideB,
490 index_t StrideC,
491 index_t KBatch,
492 AElementwiseOperation,
493 BElementwiseOperation,
494 CElementwiseOperation) override
495 {
496 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
497 static_cast<const BDataType*>(p_b),
498 static_cast<CDataType*>(p_c),
499 M,
500 N,
501 K,
502 StrideA,
503 StrideB,
504 StrideC,
505 KBatch);
506 }
507
508 // polymorphic
509 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
510 {
511 return std::make_unique<Invoker>(Invoker{});
512 }
513
514 // polymorphic
// Builds a one-line, human-readable description of this instance: GEMM specialization, first
// character of each layout name, block size, block/wave tile shapes, vector read widths,
// pipeline scheduler/version, prefetch stages and AMmaKStride (printed under "Kpack").
// NOTE(review): the initializers of both std::map lookup tables are missing from this
// extract. The maps are indexed with operator[], which yields an empty string for a key that
// has no entry.
// NOTE(review): the printed name is "DeviceGemmXdlUniversal", which differs from this file's
// name (device_gemm_xdl_cshuffle_v3_b_preshuffle) -- confirm whether that is intended.
515 std::string GetTypeString() const override
516 {
517 auto str = std::stringstream();
518
519 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
522
523 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
529
530 // clang-format off
531 str << "DeviceGemmXdlUniversal"
532 << "<"
533 << getGemmSpecializationString(GemmSpec) << ", "
534 << std::string(ALayout::name)[0]
535 << std::string(BLayout::name)[0]
536 << std::string(CLayout::name)[0]
537 << ">"
538 << " BlkSize: "
539 << BlockSize << ", "
540 << "BlkTile: "
541 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
542 << "WaveTile: "
543 << MPerXDL<<"x"<<NPerXDL << ", "
544 << "WaveMap: "
545 << MXdlPerWave<<"x" << NXdlPerWave<<", "
546 << "VmemReadVec: "
547 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
548 << "BlkGemmPipelineScheduler: "
549 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
550 << "BlkGemmPipelineVersion: "
551 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
552 << "BlkGemmPipelinePrefetchStages: "
553 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages << ", "
554 << "Kpack: "
555 << GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride;
556 // clang-format on
557
558 return str.str();
559 }
561};
562
563} // namespace device
564} // namespace tensor_operation
565} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
@ Even
Definition blkgemmpipe_scheduler.hpp:34
@ Odd
Definition blkgemmpipe_scheduler.hpp:33
__global__ void kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:75
bool is_xdl_wmma_supported()
Definition host_utility/device_prop.hpp:76
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
__global__ void kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:36
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:157
Definition data_type.hpp:187
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:158
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:392
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:160
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:79
GridwiseGemmBase< NXdlPerWave32 > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:136
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:479
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:509
bool GetPermuteB() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:460
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:462
bool GetPermuteA() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:459
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:138
static constexpr index_t APackedSize
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:140
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:515
GridwiseGemmBase< math::max(NXdlPerWave64, 1)> GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:135
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:81
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:82
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:405
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:399
GridwiseGemm_xdl_cshuffle_v3_b_preshuffle< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, PermuteA, PermuteB > GridwiseGemmBase
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:86
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:452
static constexpr index_t BPackedSize
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:147
index_t GetKPerBlock() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:457
int GetPreShuffleParameters() override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:154
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition device_gemm_xdl_cshuffle_v3_b_preshuffle.hpp:482
Definition flush_cache.hpp:299