device_fpAintB_gemm_wmma.hpp Source File

device_fpAintB_gemm_wmma.hpp Source File#

Composable Kernel: device_fpAintB_gemm_wmma.hpp Source File
device_fpAintB_gemm_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <iostream>
7#include <sstream>
8
19
20namespace ck {
21namespace tensor_operation {
22namespace device {
23
24// 1. DequantB(K, N) = int2fp(B(K, N)) * scale(1, N)
25// 2. C(M, N) = A(M, K) * DequantB(K, N)
26
27template <typename ALayout,
28 typename BLayout,
29 typename CLayout,
30 typename ADataType,
31 typename BDataType,
32 typename ScaleDataType,
33 typename CDataType,
34 typename AccDataType,
35 typename CShuffleDataType,
36 typename AElementwiseOperation,
37 typename BElementwiseOperation,
38 typename CElementwiseOperation,
39 GemmSpecialization GemmSpec,
40 ck::index_t NumPrefetch,
41 ck::index_t BlockSize,
42 ck::index_t MPerBlock,
43 ck::index_t NPerBlock,
44 ck::index_t KPerBlock,
45 ck::index_t K1,
46 ck::index_t MPerWmma,
47 ck::index_t NPerWmma,
48 ck::index_t MRepeat,
49 ck::index_t NRepeat,
50 typename ABlockTransferThreadClusterLengths_K0_M_K1,
51 typename ABlockTransferThreadClusterArrangeOrder,
52 typename ABlockTransferSrcAccessOrder,
53 ck::index_t ABlockTransferSrcVectorDim,
54 ck::index_t ABlockTransferSrcScalarPerVector,
55 ck::index_t ABlockTransferDstScalarPerVector_K1,
56 bool ABlockLdsAddExtraM,
57 typename BBlockTransferThreadClusterLengths_K0_N_K1,
58 typename BBlockTransferThreadClusterArrangeOrder,
59 typename BBlockTransferSrcAccessOrder,
60 ck::index_t BBlockTransferSrcVectorDim,
61 ck::index_t BBlockTransferSrcScalarPerVector,
62 ck::index_t BBlockTransferDstScalarPerVector_K1,
63 bool BBlockLdsAddExtraN,
64 index_t CShuffleMRepeatPerShuffle,
65 index_t CShuffleNRepeatPerShuffle,
66 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
71 BLayout,
72 CLayout,
73 ADataType,
74 BDataType,
75 CDataType,
76 AElementwiseOperation,
77 BElementwiseOperation,
78 CElementwiseOperation>
79{
80 static constexpr auto I0 = Number<0>{};
81 static constexpr auto I1 = Number<1>{};
82 static constexpr auto I2 = Number<2>{};
83 static constexpr auto I3 = Number<3>{};
84 static constexpr auto I4 = Number<4>{};
85 static constexpr auto I5 = Number<5>{};
86 static constexpr auto I6 = Number<6>{};
87 // K1 = Max Vector Access Pixels
88 static constexpr auto K1Number = Number<K1>{};
89
90 static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
91 static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
92 static constexpr auto WmmaK = K1 == 16 ? 32 : 16;
93
94 static constexpr auto AEnableLds_auto =
96 static constexpr auto BEnableLds_auto =
98
99 // If true, LDS is used unconditionally
100 // LDS bypass feature not implemented for dequantization pipeline.
101 static constexpr auto AEnableLds_manu = true;
102 static constexpr auto BEnableLds_manu = true;
103
104 static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
105 static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
106
107 static constexpr auto matrix_padder =
108 MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
109
111
112 // Describe how data read from Global memory
113 static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
114 {
115 const auto a_grid_desc_m_k = [&]() {
117 {
118 const auto a_grid_desc_mraw_kraw =
120
121 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
122 }
124 {
125 const auto a_grid_desc_mraw_kraw =
127
128 return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
129 }
130 }();
131
132 const auto M = a_grid_desc_m_k.GetLength(I0);
133 const auto K = a_grid_desc_m_k.GetLength(I1);
134 assert(K % K1 == 0);
135
136 if constexpr(AEnableLds)
137 {
138 const index_t K0 = K / K1;
139
141 a_grid_desc_m_k,
146 }
147 else
148 {
149 constexpr auto A_KRow = 2;
150 constexpr auto A_K0PerWmma = WmmaK / A_KRow / K1Number;
151 const auto A_KWmma = K / WmmaK;
152
153 const auto M0 = M / MPerBlock;
154 // 0 1 0 1 2 3 4 5 6
155 // M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
157 a_grid_desc_m_k,
161 make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
164 }
165 }
166
167 static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
168 {
169 const auto b_grid_desc_n_k = [&]() {
171 {
172 const auto b_grid_desc_nraw_kraw =
174
175 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
176 }
178 {
179 const auto b_grid_desc_nraw_kraw =
181
182 return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
183 }
184 }();
185
186 const auto N = b_grid_desc_n_k.GetLength(I0);
187 const auto K = b_grid_desc_n_k.GetLength(I1);
188 assert(K % K1 == 0);
189
190 if constexpr(BEnableLds)
191 {
192 const index_t K0 = K / K1;
193
195 b_grid_desc_n_k,
200 }
201 else
202 {
203 constexpr auto B_KRow = 2;
204 constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
205 const auto B_KWmma = K / WmmaK;
206
207 const auto N0 = N / NPerBlock;
208 // 0 1 0 1 2 3 4 5 6
209 // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
211 b_grid_desc_n_k,
215 make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
218 }
219 }
220
221 static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB = 0)
222 {
223 // assume Scale is [1, N]
224 const auto scale_grid_desc_n_k = [&]() {
225 const auto scale_grid_desc_nraw_kraw =
227
228 return matrix_padder.PadBDescriptor_N_K(scale_grid_desc_nraw_kraw);
229 }();
230
231 const auto N = scale_grid_desc_n_k.GetLength(I0);
232 const auto K = scale_grid_desc_n_k.GetLength(I1);
233 // When K = 1, it might be scale tensor.
234 assert(K % K1 == 0 && K != 1);
235
236 if constexpr(BEnableLds)
237 {
238 const index_t K0 = K / K1;
239
241 scale_grid_desc_n_k,
242 make_tuple(make_unmerge_transform(make_tuple(K0, 1)), // Reduce K1 = 1
246 }
247 else
248 {
249 constexpr auto B_KRow = 2;
250 constexpr auto B_K0PerWmma = WmmaK / B_KRow / K1Number;
251 const auto B_KWmma = K / WmmaK;
252
253 const auto N0 = N / NPerBlock;
254 // 0 1 0 1 2 3 4 5 6
255 // N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
257 scale_grid_desc_n_k,
261 make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
264 }
265 }
266
267 static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
268 {
269 const auto c_grid_desc_mraw_nraw = [&]() {
271 {
272 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
273 make_tuple(StrideC, I1));
274 }
276 {
277 return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
278 make_tuple(I1, StrideC));
279 }
280 }();
281
282 return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
283 }
284
285 // Gridwise descriptor, mapping to the whole given problem.
286 using AGridDesc = decltype(MakeAGridDescriptor(1, 1, 1));
287 using BGridDesc = decltype(MakeBGridDescriptor(1, 1, 1));
288 using ScaleGridDesc = decltype(MakeScaleGridDescriptor(1, 1, 0));
289 using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
290
291 // GridwiseGemm
293 BlockSize,
294 ADataType,
295 BDataType,
296 ScaleDataType,
297 AccDataType,
298 CShuffleDataType,
299 CDataType,
301 AGridDesc,
302 BGridDesc,
305 AElementwiseOperation,
306 BElementwiseOperation,
307 CElementwiseOperation,
308 MPerBlock,
309 NPerBlock,
310 KPerBlock,
311 MPerWmma,
312 NPerWmma,
313 K1,
314 MRepeat,
315 NRepeat,
316 ABlockTransferThreadClusterLengths_K0_M_K1,
317 ABlockTransferThreadClusterArrangeOrder,
318 ABlockTransferSrcAccessOrder,
319 ABlockTransferSrcVectorDim,
320 ABlockTransferSrcScalarPerVector,
321 ABlockTransferDstScalarPerVector_K1,
322 false, // AThreadTransferSrcResetCoordinateAfterRun,
324 ABlockLdsAddExtraM,
325 BBlockTransferThreadClusterLengths_K0_N_K1,
326 BBlockTransferThreadClusterArrangeOrder,
327 BBlockTransferSrcAccessOrder,
328 BBlockTransferSrcVectorDim,
329 BBlockTransferSrcScalarPerVector,
330 BBlockTransferDstScalarPerVector_K1,
331 false, // BThreadTransferSrcResetCoordinateAfterRun,
333 BBlockLdsAddExtraN,
334 CShuffleMRepeatPerShuffle,
335 CShuffleNRepeatPerShuffle,
336 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
337 CShuffleBlockTransferScalarPerVector_NPerBlock,
338 NumPrefetch,
339 LoopSched,
340 PipelineVer>;
341
342 // Argument
343 struct Argument : public BaseArgument
344 {
345 Argument(const ADataType* p_a_grid,
346 const BDataType* p_b_grid,
347 const ScaleDataType* p_scale_grid,
348 CDataType* p_c_grid,
349 index_t M,
350 index_t N,
351 index_t K,
352 index_t StrideA,
353 index_t StrideB,
354 index_t StrideC,
355 index_t M01,
356 index_t N01,
357 AElementwiseOperation a_element_op,
358 BElementwiseOperation b_element_op,
359 CElementwiseOperation c_element_op)
360 : p_a_grid_{p_a_grid},
361 p_b_grid_{p_b_grid},
362 p_scale_grid_{p_scale_grid},
363 p_c_grid_{p_c_grid},
364 a_grid_desc_{},
365 b_grid_desc_{},
370 M01_{M01},
371 N01_{N01},
372 a_element_op_{a_element_op},
373 b_element_op_{b_element_op},
374 c_element_op_{c_element_op},
375 MRaw_{M},
376 NRaw_{N},
377 KRaw_{K}
378 {
383
386
389 {
393 }
394 }
395
396 // private:
397 const ADataType* p_a_grid_;
398 const BDataType* p_b_grid_;
399 const ScaleDataType* p_scale_grid_;
400 CDataType* p_c_grid_;
410 AElementwiseOperation a_element_op_;
411 BElementwiseOperation b_element_op_;
412 CElementwiseOperation c_element_op_;
413 // for checking vector load/store
417 };
418
419 // Invoker
420 struct Invoker : public BaseInvoker
421 {
423
424 float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
425 {
427 arg.b_grid_desc_,
430 {
431 throw std::runtime_error(
432 "wrong! GridwiseGemm_k0mk1_k0nk1_m0nm1_wmma_v1r1 has invalid setting");
433 }
434
435 const index_t grid_size =
436 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
437
438 const auto K = [&]() {
439 if constexpr(AEnableLds)
440 {
441 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
442 }
443 else
444 {
445 return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
446 arg.a_grid_desc_.GetLength(I4) * arg.a_grid_desc_.GetLength(I6);
447 }
448 }();
449 auto launch_kernel = [&](auto has_main_k_block_loop) {
450 const auto kernel = kernel_fpAintB_gemm_wmma<
451 GridwiseGemm,
452 ADataType,
453 BDataType,
454 ScaleDataType,
455 CDataType,
461 AElementwiseOperation,
462 BElementwiseOperation,
463 CElementwiseOperation,
465 has_main_k_block_loop>;
466
467 return launch_and_time_kernel(stream_config,
468 kernel,
469 dim3(grid_size),
470 dim3(BlockSize),
471 0,
472 arg.p_a_grid_,
473 arg.p_b_grid_,
474 arg.p_scale_grid_,
475 arg.p_c_grid_,
476 arg.a_grid_desc_,
477 arg.b_grid_desc_,
480 arg.a_element_op_,
481 arg.b_element_op_,
482 arg.c_element_op_,
484 };
485
487 {
488 return launch_kernel(integral_constant<bool, true>{});
489 }
490 else
491 {
492 return launch_kernel(integral_constant<bool, false>{});
493 }
494 }
495
496 // polymorphic
497 float Run(const BaseArgument* p_arg,
498 const StreamConfig& stream_config = StreamConfig{}) override
499 {
500 return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
501 }
502 };
503
504 static constexpr bool IsValidCompilationParameter()
505 {
506 // TODO: properly implement this check
507 return true;
508 }
509
510 static bool IsSupportedArgument(const Argument& arg)
511 {
513 {
516 {
517 printf("DeviceOp err: AccDataType");
518 return false;
519 }
520 }
521 else
522 {
523 printf("DeviceOp err: Arch");
524 return false;
525 }
526
527 // check vector load/store
528 {
531
532 // check vector load of A
533 if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
534 {
535 if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
536 {
537 return false;
538 }
539 }
540 else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
541 {
542 // FIXME: not rigorous
543 if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
544 {
545 return false;
546 }
547 }
548 else
549 {
550 return false;
551 }
552
553 // check vector load of B
554 if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
555 {
556 if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
557 {
558 return false;
559 }
560 }
561 else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
562 {
563 // FIXME: not rigorous
564 if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
565 {
566 return false;
567 }
568 }
569 else
570 {
571 return false;
572 }
573
574 // check vector store of C
575 // only support RowMajor for now
576 if constexpr(is_same_v<CLayout, Row>)
577 {
578 if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
579 {
580 return false;
581 }
582 }
583 else
584 {
585 return false;
586 }
587 }
588
591 }
592
593 // polymorphic
594 bool IsSupportedArgument(const BaseArgument* p_arg) override
595 {
596 return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
597 }
598
599 static auto MakeArgument(const ADataType* p_a,
600 const BDataType* p_b,
601 const ScaleDataType* p_scale,
602 CDataType* p_c,
603 index_t M,
604 index_t N,
605 index_t K,
606 index_t StrideA,
607 index_t StrideB,
608 index_t StrideC,
609 AElementwiseOperation a_element_op,
610 BElementwiseOperation b_element_op,
611 CElementwiseOperation c_element_op)
612 {
613 return Argument{p_a,
614 p_b,
615 p_scale,
616 p_c,
617 M,
618 N,
619 K,
620 StrideA,
621 StrideB,
622 StrideC,
623 1,
624 1,
625 a_element_op,
626 b_element_op,
627 c_element_op};
628 }
629
630 static auto MakeInvoker() { return Invoker{}; }
631
632 // polymorphic
633 std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
634 const void* p_b,
635 const void* p_scale,
636 void* p_c,
637 index_t M,
638 index_t N,
639 index_t K,
640 index_t StrideA,
641 index_t StrideB,
642 index_t StrideC,
643 AElementwiseOperation a_element_op,
644 BElementwiseOperation b_element_op,
645 CElementwiseOperation c_element_op) override
646 {
647 return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
648 static_cast<const BDataType*>(p_b),
649 static_cast<const ScaleDataType*>(p_scale),
650 static_cast<CDataType*>(p_c),
651 M,
652 N,
653 K,
654 StrideA,
655 StrideB,
656 StrideC,
657 1,
658 1,
659 a_element_op,
660 b_element_op,
661 c_element_op);
662 }
663
664 // polymorphic
665 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
666 {
667 return std::make_unique<Invoker>(Invoker{});
668 }
669
670 // polymorphic
671 std::string GetTypeString() const override
672 {
673 auto str = std::stringstream();
674
675 std::map<LoopScheduler, std::string> LoopSchedToString{
676 {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
677
678 std::map<PipelineVersion, std::string> PipelineVersionToString{
679 {PipelineVersion::v1, "v1"},
680 {PipelineVersion::v2, "v2"},
681 {PipelineVersion::weight_only, "weight_only"}};
682
683 // clang-format off
684 str << "DeviceFpAintBGemm_Wmma_CShuffle"
685 << "<"
686 << BlockSize << ", "
687 << MPerBlock << ", "
688 << NPerBlock << ", "
689 << KPerBlock << ", "
690 << K1 << ", "
691 << MPerWmma << ", "
692 << NPerWmma << ", "
693 << MRepeat << ", "
694 << NRepeat
695 << ">"
696 << " AEnableLds: "
697 << AEnableLds << ", "
698 << "BEnableLds: "
699 << BEnableLds << ", "
700 << "NumPrefetch: "
701 << NumPrefetch << ", "
702 << "LoopScheduler: "
703 << LoopSchedToString[LoopSched] << ", "
704 << "PipelineVersion: "
705 << PipelineVersionToString[PipelineVer];
706 // clang-format on
707
708 return str.str();
709 }
710};
711
712} // namespace device
713} // namespace tensor_operation
714} // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:14
Definition convolution_backward_data_specialization.hpp:8
GemmSpecialization
Definition gemm_specialization.hpp:11
Definition convolution_backward_data_specialization.hpp:7
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition tile/host/kernel_launch.hpp:173
Definition ck.hpp:268
__host__ __device__ constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition multi_index_transform_helper.hpp:12
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition tensor_descriptor_helper.hpp:49
@ Set
Definition ck.hpp:278
integral_constant< index_t, N > Number
Definition number.hpp:12
bool is_gfx12_supported()
Definition host_utility/device_prop.hpp:55
constexpr bool is_same_v
Definition type.hpp:283
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
LoopScheduler
Definition loop_scheduler.hpp:15
@ Default
Definition loop_scheduler.hpp:16
@ Interwave
Definition loop_scheduler.hpp:17
__host__ __device__ constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition tensor_description/tensor_descriptor.hpp:319
PipelineVersion
Definition gridwise_gemm_pipeline_selector.hpp:18
@ v2
Definition gridwise_gemm_pipeline_selector.hpp:20
@ v1
Definition gridwise_gemm_pipeline_selector.hpp:19
@ weight_only
Definition gridwise_gemm_pipeline_selector.hpp:23
__host__ __device__ constexpr auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition multi_index_transform_helper.hpp:90
bool is_gfx11_supported()
Definition host_utility/device_prop.hpp:60
__global__ void kernel_fpAintB_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, const ScaleDataType *__restrict__ p_scale_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const ScaleGridDesc scale_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition gridwise_fpAintB_gemm_wmma.hpp:40
constexpr LoopScheduler make_default_loop_scheduler()
Definition loop_scheduler.hpp:20
Definition ck/stream_config.hpp:10
Definition gridwise_fpAintB_gemm_wmma.hpp:136
Definition utility/sequence.hpp:43
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition tensor_operation/gpu/device/tensor_layout.hpp:31
Definition tensor_operation/gpu/device/tensor_layout.hpp:26
Definition device_base.hpp:197
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition device_fpAintB_gemm_wmma.hpp:407
index_t M01_
Definition device_fpAintB_gemm_wmma.hpp:408
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, const ScaleDataType *p_scale_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_fpAintB_gemm_wmma.hpp:345
index_t KRaw_
Definition device_fpAintB_gemm_wmma.hpp:416
AGridDesc a_grid_desc_
Definition device_fpAintB_gemm_wmma.hpp:401
BGridDesc b_grid_desc_
Definition device_fpAintB_gemm_wmma.hpp:402
CElementwiseOperation c_element_op_
Definition device_fpAintB_gemm_wmma.hpp:412
index_t MRaw_
Definition device_fpAintB_gemm_wmma.hpp:414
const BDataType * p_b_grid_
Definition device_fpAintB_gemm_wmma.hpp:398
CDataType * p_c_grid_
Definition device_fpAintB_gemm_wmma.hpp:400
index_t N01_
Definition device_fpAintB_gemm_wmma.hpp:409
index_t NRaw_
Definition device_fpAintB_gemm_wmma.hpp:415
CGridDesc_M_N c_grid_desc_m_n_
Definition device_fpAintB_gemm_wmma.hpp:404
ScaleGridDesc scale_grid_desc_
Definition device_fpAintB_gemm_wmma.hpp:403
const ADataType * p_a_grid_
Definition device_fpAintB_gemm_wmma.hpp:397
BElementwiseOperation b_element_op_
Definition device_fpAintB_gemm_wmma.hpp:411
AElementwiseOperation a_element_op_
Definition device_fpAintB_gemm_wmma.hpp:410
const ScaleDataType * p_scale_grid_
Definition device_fpAintB_gemm_wmma.hpp:399
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition device_fpAintB_gemm_wmma.hpp:406
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition device_fpAintB_gemm_wmma.hpp:497
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition device_fpAintB_gemm_wmma.hpp:424
DeviceOp::Argument Argument
Definition device_fpAintB_gemm_wmma.hpp:422
Definition device_fpAintB_gemm_wmma.hpp:79
static constexpr auto AEnableLds_manu
Definition device_fpAintB_gemm_wmma.hpp:101
static constexpr auto I5
Definition device_fpAintB_gemm_wmma.hpp:85
static constexpr auto AEnableLds
Definition device_fpAintB_gemm_wmma.hpp:104
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition device_fpAintB_gemm_wmma.hpp:287
static constexpr auto I6
Definition device_fpAintB_gemm_wmma.hpp:86
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_scale, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition device_fpAintB_gemm_wmma.hpp:633
std::string GetTypeString() const override
Definition device_fpAintB_gemm_wmma.hpp:671
decltype(MakeScaleGridDescriptor(1, 1, 0)) ScaleGridDesc
Definition device_fpAintB_gemm_wmma.hpp:288
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition device_fpAintB_gemm_wmma.hpp:289
static auto MakeScaleGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB=0)
Definition device_fpAintB_gemm_wmma.hpp:221
static constexpr auto I1
Definition device_fpAintB_gemm_wmma.hpp:81
DeviceFpAintBGemm_Wmma_CShuffle DeviceOp
Definition device_fpAintB_gemm_wmma.hpp:110
static constexpr auto BEnableLds_auto
Definition device_fpAintB_gemm_wmma.hpp:96
static constexpr auto BEnableLds_manu
Definition device_fpAintB_gemm_wmma.hpp:102
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_fpAintB_gemm_wmma.hpp:665
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition device_fpAintB_gemm_wmma.hpp:286
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition device_fpAintB_gemm_wmma.hpp:267
static bool IsSupportedArgument(const Argument &arg)
Definition device_fpAintB_gemm_wmma.hpp:510
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition device_fpAintB_gemm_wmma.hpp:113
static constexpr auto I0
Definition device_fpAintB_gemm_wmma.hpp:80
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition device_fpAintB_gemm_wmma.hpp:167
static constexpr auto I3
Definition device_fpAintB_gemm_wmma.hpp:83
static constexpr auto AEnableLds_auto
Definition device_fpAintB_gemm_wmma.hpp:94
static constexpr auto I4
Definition device_fpAintB_gemm_wmma.hpp:84
static constexpr auto NWaves
Definition device_fpAintB_gemm_wmma.hpp:91
static constexpr auto matrix_padder
Definition device_fpAintB_gemm_wmma.hpp:107
static constexpr auto MWaves
Definition device_fpAintB_gemm_wmma.hpp:90
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_fpAintB_gemm_wmma.hpp:594
static constexpr auto K1Number
Definition device_fpAintB_gemm_wmma.hpp:88
GridwiseFpAintBGemm_Wmma< BlockSize, ADataType, BDataType, ScaleDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, ScaleGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseGemm
Definition device_fpAintB_gemm_wmma.hpp:292
static constexpr auto WmmaK
Definition device_fpAintB_gemm_wmma.hpp:92
static constexpr auto BEnableLds
Definition device_fpAintB_gemm_wmma.hpp:105
static constexpr auto I2
Definition device_fpAintB_gemm_wmma.hpp:82
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const ScaleDataType *p_scale, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition device_fpAintB_gemm_wmma.hpp:599
static auto MakeInvoker()
Definition device_fpAintB_gemm_wmma.hpp:630
static constexpr bool IsValidCompilationParameter()
Definition device_fpAintB_gemm_wmma.hpp:504
Definition device_gemm_dequantB.hpp:25
Definition matrix_padder.hpp:180