device_gemm_xdl_cshuffle_v3_mx.hpp Source File

device_gemm_xdl_cshuffle_v3_mx.hpp Source File#

Composable Kernel: device_gemm_xdl_cshuffle_v3_mx.hpp Source File
device_gemm_xdl_cshuffle_v3_mx.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
10
21
22namespace ck {
23namespace tensor_operation {
24namespace device {
25
26// clang-format off
103// clang-format on
104template <typename ALayout,
105 typename BLayout,
106 typename CLayout,
107 typename ADataType,
108 typename AScaleDataType,
109 typename BDataType,
110 typename BScaleDataType,
111 typename CDataType,
112 typename GemmAccDataType, // TODO: always float
113 typename CShuffleDataType,
114 typename AElementwiseOperation,
115 typename BElementwiseOperation,
116 typename CElementwiseOperation,
117 GemmSpecialization GemmSpec,
118 index_t ScaleBlockSize, // Scaling block size
119 index_t BlockSize, // Thread block size
120 index_t MPerBlock,
121 index_t NPerBlock,
122 index_t KPerBlock, // multiply with packed_size_v to get the actual KPerBlock
123 index_t AK1,
124 index_t BK1,
125 index_t MPerXDL,
126 index_t NPerXDL,
127 index_t MXdlPerWave,
128 index_t NXdlPerWave,
129 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
130 typename ABlockTransferThreadClusterArrangeOrder,
131 typename ABlockTransferSrcAccessOrder,
132 index_t ABlockTransferSrcVectorDim,
133 index_t ABlockTransferSrcScalarPerVector,
134 index_t ABlockTransferDstScalarPerVector_AK1,
135 bool ABlockLdsExtraM,
136 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
137 typename BBlockTransferThreadClusterArrangeOrder,
138 typename BBlockTransferSrcAccessOrder,
139 index_t BBlockTransferSrcVectorDim,
140 index_t BBlockTransferSrcScalarPerVector,
141 index_t BBlockTransferDstScalarPerVector_BK1,
142 bool BBlockLdsExtraN,
143 index_t CShuffleMXdlPerWavePerShuffle,
144 index_t CShuffleNXdlPerWavePerShuffle,
145 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
146 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
149 typename ComputeTypeA =
150 ADataType, // XXX: These should always be the same as ADataType and BDataType
151 typename ComputeTypeB =
152 BDataType // TODO: Hardcode them and remove from the list of template parameters
153 >
155 BLayout,
156 CLayout,
157 ADataType,
158 AScaleDataType,
159 BDataType,
160 BScaleDataType,
161 CDataType,
162 ScaleBlockSize,
163 AElementwiseOperation,
164 BElementwiseOperation,
165 CElementwiseOperation>
166{
// Result of GetNXdlPerWave<true>(); IsSupportedArgument consults this value on the
// get_warp_size() == 64 path and only validates when it is > 0.
168 static constexpr auto NXdlPerWave64 = GetNXdlPerWave<true>();
// Result of GetNXdlPerWave<false>(); IsSupportedArgument consults this value on the
// non-64 warp-size path and only validates when it is > 0.
169 static constexpr auto NXdlPerWave32 = GetNXdlPerWave<false>();
170
171 // GridwiseGemm
172 template <index_t NXdlPerWave_>
174 ALayout,
175 BLayout,
176 CLayout,
177 ADataType,
178 AScaleDataType,
179 BDataType,
180 BScaleDataType,
181 GemmAccDataType,
182 CShuffleDataType,
183 CDataType,
184 AElementwiseOperation,
185 BElementwiseOperation,
186 CElementwiseOperation,
187 GemmSpec,
188 ScaleBlockSize,
189 BlockSize,
190 MPerBlock,
191 NPerBlock,
192 KPerBlock,
193 AK1,
194 BK1,
195 MPerXDL,
196 NPerXDL,
197 MXdlPerWave,
198 NXdlPerWave_,
199 ABlockTransferThreadClusterLengths_AK0_M_AK1,
200 ABlockTransferThreadClusterArrangeOrder,
201 ABlockTransferSrcAccessOrder,
202 ABlockTransferSrcVectorDim,
203 ABlockTransferSrcScalarPerVector,
204 ABlockTransferDstScalarPerVector_AK1,
205 false,
206 ABlockLdsExtraM,
207 BBlockTransferThreadClusterLengths_BK0_N_BK1,
208 BBlockTransferThreadClusterArrangeOrder,
209 BBlockTransferSrcAccessOrder,
210 BBlockTransferSrcVectorDim,
211 BBlockTransferSrcScalarPerVector,
212 BBlockTransferDstScalarPerVector_BK1,
213 false,
214 BBlockLdsExtraN,
215 CShuffleMXdlPerWavePerShuffle,
216 CShuffleNXdlPerWavePerShuffle,
217 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
218 CShuffleBlockTransferScalarPerVector_NPerBlock,
219 BlkGemmPipeSched,
220 BlkGemmPipelineVer,
221 ComputeTypeA,
222 ComputeTypeB>;
223 template <index_t NXdlPerWave_>
225 ALayout,
226 BLayout,
227 CLayout,
228 ADataType,
229 AScaleDataType,
230 BDataType,
231 BScaleDataType,
232 GemmAccDataType,
233 CShuffleDataType,
234 CDataType,
235 AElementwiseOperation,
236 BElementwiseOperation,
237 CElementwiseOperation,
238 GemmSpec,
239 ScaleBlockSize,
240 BlockSize,
241 MPerBlock,
242 NPerBlock,
243 KPerBlock,
244 AK1,
245 BK1,
246 MPerXDL,
247 NPerXDL,
248 MXdlPerWave,
249 NXdlPerWave_,
250 ABlockTransferThreadClusterLengths_AK0_M_AK1,
251 ABlockTransferThreadClusterArrangeOrder,
252 ABlockTransferSrcAccessOrder,
253 ABlockTransferSrcVectorDim,
254 ABlockTransferSrcScalarPerVector,
255 ABlockTransferDstScalarPerVector_AK1,
256 false,
257 ABlockLdsExtraM,
258 BBlockTransferThreadClusterLengths_BK0_N_BK1,
259 BBlockTransferThreadClusterArrangeOrder,
260 BBlockTransferSrcAccessOrder,
261 BBlockTransferSrcVectorDim,
262 BBlockTransferSrcScalarPerVector,
263 BBlockTransferDstScalarPerVector_BK1,
264 false,
265 BBlockLdsExtraN,
266 CShuffleMXdlPerWavePerShuffle,
267 CShuffleNXdlPerWavePerShuffle,
268 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
269 CShuffleBlockTransferScalarPerVector_NPerBlock,
270 BlkGemmPipeSched,
271 BlkGemmPipelineVer,
272 ComputeTypeA,
273 ComputeTypeB>;
274
283
// Public argument type of this device op: the Argument struct of GridwiseGemm64.
// MakeArgument / MakeArgumentPointer construct it and the polymorphic entry points
// dynamic_cast a BaseArgument pointer to it.
284 using Argument = typename GridwiseGemm64::Argument;
285
286 // Invoker
287 struct Invoker : public BaseInvoker
288 {
289 template <typename GridwiseGemm>
290 float RunImp(const typename GridwiseGemm::Argument& arg,
291 const StreamConfig& stream_config = StreamConfig{})
292 {
293 if(stream_config.log_level_ > 0)
294 {
295 arg.Print();
296 GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
297 }
298
299 if(!GridwiseGemm::CheckValidity(arg))
300 {
301 throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
302 }
303
304 index_t gdx, gdy, gdz;
305 std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
306
307 float ave_time = 0;
308
309 index_t k_grain = arg.KBatch * KPerBlock;
310 index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
311
312 const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
313
314 const auto Run = [&](const auto& kernel) {
315 if(stream_config.flush_cache)
316 {
317 auto arg_ = arg;
318
319 const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
320 arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
321 const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
322 arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
323
324 auto size_a_buffer =
325 a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType);
326 auto size_b_buffer =
327 b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType);
328
330 arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
331 rotating_mem.Print();
332
333 auto run_flush_cache = [&]() {
334 // flush icache
336 // rotating mem
337 rotating_mem.Next();
338 // clear c mem
339 if(arg_.KBatch > 1)
340 hipGetErrorString(hipMemsetAsync(arg_.p_c_grid,
341 0,
342 arg_.M * arg_.N * sizeof(CDataType),
343 stream_config.stream_id_));
344 };
345
347 stream_config,
348 run_flush_cache,
349 kernel,
350 dim3(gdx, gdy, gdz),
351 dim3(BlockSize),
352 0,
353 arg_);
354 }
355 else
356 {
357 if(arg.KBatch > 1)
358 hipGetErrorString(hipMemsetAsync(arg.p_c_grid,
359 0,
360 arg.M * arg.N * sizeof(CDataType),
361 stream_config.stream_id_));
362
363 ave_time = launch_and_time_kernel(
364 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
365 }
366 };
367
368 // TODO: Check if this is the right algorithm for minimum_occupancy
369 constexpr index_t minimum_occupancy =
370 BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave
371 ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 &&
372 MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2)
373 ? 2
374 : 1
375 : 2;
376
377 constexpr auto TailNumChoices = []() {
378 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
379 return Tuple<constant<TailNumber::Full>>{};
380 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
381 return Tuple<constant<TailNumber::Even>, constant<TailNumber::Odd>>{};
382 else
383 static_assert(false, "Unexpected BlkGemmPipelineVer!");
384 }();
385 constexpr bool Use2LDS = []() {
386 if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
387 return false;
388 else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
389 return true;
390 else
391 static_assert(false, "Unexpected BlkGemmPipelineVer!");
392 }();
393 const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split);
394 using BoolChoices = Tuple<ck::true_type, ck::false_type>;
395 static_for_product<BoolChoices,
396 BoolChoices,
397 remove_cvref_t<decltype(TailNumChoices)>>{}(
398 [&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) {
399 constexpr auto CGlobalMemoryDataOperation =
400 KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd
402 if(mainloop_choice.value == has_main_k_block_loop &&
403 KBatch_cond_choice.value == (arg.KBatch > 1) &&
404 tail_num_choice.value == tail_num)
405 {
406 const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< //
407 Use2LDS,
408 GridwiseGemm,
409 mainloop_choice.value,
410 CGlobalMemoryDataOperation,
411 minimum_occupancy,
412 tail_num_choice.value>;
413 Run(kernel);
414 }
415 });
416 return ave_time;
417 }
418
420 // polymorphic
421 float Run(const BaseArgument* p_arg,
422 const StreamConfig& stream_config = StreamConfig{}) override
423 {
424 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
425 }
426 };
427
428 static constexpr bool IsValidCompilationParameter()
429 {
430 static_assert(is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>(),
431 "Only microscaling formats are supported for ADataType and BDataType");
432
433 static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported");
434
436 "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType");
437
438 return true;
439 }
440
441 static bool IsSupportedArgument(const Argument& arg)
442 {
443 if constexpr(!IsValidCompilationParameter())
444 {
445 return false;
446 }
447
448 if(ck::get_device_name() != "gfx950")
449 {
450 return false;
451 }
452
453 if(!is_bf16_atomic_supported() && std::is_same_v<CDataType, ck::bhalf_t> && arg.KBatch > 1)
454 {
455 return false;
456 }
457
458 if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
459 GemmSpec == GemmSpecialization::NKPadding ||
460 GemmSpec == GemmSpecialization::MNKPadding ||
461 GemmSpec == GemmSpecialization::KPadding))
462 {
463 return false;
464 }
465
466 if(get_warp_size() == 64)
467 {
468 if constexpr(NXdlPerWave64 > 0)
469 {
470 return GridwiseGemm64::CheckValidity(arg);
471 }
472 }
473 else
474 {
475 if constexpr(NXdlPerWave32 > 0)
476 {
477 return GridwiseGemm32::CheckValidity(
478 reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg));
479 }
480 }
481 return false;
482 }
483
484 // polymorphic
485 bool IsSupportedArgument(const BaseArgument* p_arg) override
486 {
487 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
488 }
489
490 static auto MakeArgument(const ADataType* p_a,
491 const AScaleDataType* p_a_scale,
492 const BDataType* p_b,
493 const BScaleDataType* p_b_scale,
494 CDataType* p_c,
495 index_t M,
496 index_t N,
497 index_t K,
498 index_t StrideA,
499 index_t StrideScaleA,
500 index_t StrideB,
501 index_t StrideScaleB,
502 index_t StrideC,
503 index_t KBatch,
504 AElementwiseOperation a_element_op,
505 BElementwiseOperation b_element_op,
506 CElementwiseOperation c_element_op)
507 {
508 return Argument{p_a,
509 p_a_scale,
510 p_b,
511 p_b_scale,
512 p_c,
513 M,
514 N,
515 K,
516 StrideA,
517 StrideScaleA,
518 StrideB,
519 StrideScaleB,
520 StrideC,
521 KBatch,
522 a_element_op,
523 b_element_op,
524 c_element_op};
525 }
526
527 static auto MakeInvoker() { return Invoker{}; }
528
529 // polymorphic
530 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
531 const void* p_a_scale,
532 const void* p_b,
533 const void* p_b_scale,
534 void* p_c,
535 ck::index_t M,
536 ck::index_t N,
537 ck::index_t K,
538 ck::index_t StrideA,
539 ck::index_t StrideScaleA,
540 ck::index_t StrideB,
541 ck::index_t StrideScaleB,
542 ck::index_t StrideC,
543 ck::index_t KBatch,
544 AElementwiseOperation a_element_op,
545 BElementwiseOperation b_element_op,
546 CElementwiseOperation c_element_op) override
547 {
548 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
549 static_cast<const AScaleDataType*>(p_a_scale),
550 static_cast<const BDataType*>(p_b),
551 static_cast<const BScaleDataType*>(p_b_scale),
552 static_cast<CDataType*>(p_c),
553 M,
554 N,
555 K,
556 StrideA,
557 StrideScaleA,
558 StrideB,
559 StrideScaleB,
560 StrideC,
561 KBatch,
562 a_element_op,
563 b_element_op,
564 c_element_op);
565 }
566
567 // polymorphic
568 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
569 {
570 return std::make_unique<Invoker>(Invoker{});
571 }
572
573 // polymorphic
574 std::string GetTypeString() const override
575 {
576 auto str = std::stringstream();
577
578 std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
581
582 std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
588
589 // clang-format off
590 str << "DeviceGemmMX_Xdl_CShuffleV3"
591 << "<"
592 << getGemmSpecializationString(GemmSpec) << ", "
593 << std::string(ALayout::name)[0]
594 << std::string(BLayout::name)[0]
595 << std::string(CLayout::name)[0]
596 << ">"
597 << " BlkSize: "
598 << BlockSize << ", "
599 << "BlkTile: "
600 << MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
601 << "WaveTile: "
602 << MPerXDL<<"x"<<NPerXDL << ", "
603 << "WaveMap: "
604 << MXdlPerWave<<"x" << NXdlPerWave<<", "
605 << "VmemReadVec: "
606 << ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
607 << "BlkGemmPipelineScheduler: "
608 << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
609 << "BlkGemmPipelineVersion: "
610 << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
611 << "BlkGemmPipelinePrefetchStages: "
612 << GridwiseGemm64::BlockwiseGemmPipe::PrefetchStages << ", "
613 << "Kpack: "
614 << GridwiseGemm64::BlockwiseGemmPipe::AMmaKStride << ", "
615 << "ScaleBlockSize: "
616 << ScaleBlockSize;
617 // clang-format on
618
619 return str.str();
620 }
622};
623
624} // namespace device
625} // namespace tensor_operation
626} // namespace ck
#define INVOKER_RUN3_IMPL
Definition device_base.hpp:114
#define REGISTER_EXTRA_PRINTING_METHODS
Definition device_base.hpp:47
#define GET_NXDL_PER_WAVE_IMPL
Definition device_base.hpp:81
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition convolution_backward_data_specialization.hpp:8
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition gemm_specialization.hpp:32
GemmSpecialization
Definition gemm_specialization.hpp:11
@ MKPadding
Definition gemm_specialization.hpp:18
@ KPadding
Definition gemm_specialization.hpp:16
@ MNKPadding
Definition gemm_specialization.hpp:20
@ NKPadding
Definition gemm_specialization.hpp:19
Definition convolution_backward_data_specialization.hpp:7
void flush_icache()
Definition flush_cache.hpp:383
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, GemmArgs &gemm_args, Args... args)
Definition flush_cache.hpp:398
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
typename conditional< predicate, X, Y >::type conditional_t
Definition utility/functional.hpp:115
@ Set
Definition ck.hpp:278
@ AtomicAdd
Definition ck.hpp:279
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
BlockGemmPipelineVersion
Definition blkgemmpipe_scheduler.hpp:12
@ v2
Definition blkgemmpipe_scheduler.hpp:15
@ v3
Definition blkgemmpipe_scheduler.hpp:16
@ v5
Definition blkgemmpipe_scheduler.hpp:18
@ v4
Definition blkgemmpipe_scheduler.hpp:17
@ v1
Definition blkgemmpipe_scheduler.hpp:14
std::string get_device_name()
Definition host_utility/device_prop.hpp:19
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
constexpr bool is_same_v
Definition type.hpp:283
__global__ enable_if_t<!Use2LDS, void > kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg)
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:40
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
@ Interwave
Definition blkgemmpipe_scheduler.hpp:27
bool is_bf16_atomic_supported()
Definition host_utility/device_prop.hpp:108
Definition ck/stream_config.hpp:10
Definition gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp:156
Definition gridwise_gemm_xdl_cshuffle_v3_mx.hpp:156
Definition device_base.hpp:197
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:288
INVOKER_RUN3_IMPL float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:421
float RunImp(const typename GridwiseGemm::Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:290
WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types.
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:166
static constexpr bool IsValidCompilationParameter()
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:428
static bool IsSupportedArgument(const Argument &arg)
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:441
conditional_t< !is_same_v< BLayout, tensor_layout::gemm::MFMA >, GridwiseGemmMXBase< math::max(NXdlPerWave64, 1)>, GridwiseGemmMXBPreshuffleBase< math::max(NXdlPerWave64, 1)> > GridwiseGemm64
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:275
std::string GetTypeString() const override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:574
static constexpr auto NXdlPerWave32
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:169
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideScaleA, ck::index_t StrideB, ck::index_t StrideScaleB, ck::index_t StrideC, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:530
static auto MakeArgument(const ADataType *p_a, const AScaleDataType *p_a_scale, const BDataType *p_b, const BScaleDataType *p_b_scale, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideScaleA, index_t StrideB, index_t StrideScaleB, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:490
GridwiseGemmMX_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmMXBase
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:173
static GET_NXDL_PER_WAVE_IMPL constexpr auto NXdlPerWave64
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:168
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:568
static auto MakeInvoker()
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:527
typename GridwiseGemm64::Argument Argument
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:284
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:485
GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< ALayout, BLayout, CLayout, ADataType, AScaleDataType, BDataType, BScaleDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, ScaleBlockSize, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave_, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemmMXBPreshuffleBase
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:224
conditional_t< !is_same_v< BLayout, tensor_layout::gemm::MFMA >, GridwiseGemmMXBase< NXdlPerWave32 >, GridwiseGemmMXBPreshuffleBase< NXdlPerWave32 > > GridwiseGemm32
Definition device_gemm_xdl_cshuffle_v3_mx.hpp:279
Definition device_gemm_mx.hpp:25
Definition flush_cache.hpp:299