epilogue_cshuffle_v3_wmma.hpp Source File

epilogue_cshuffle_v3_wmma.hpp Source File

Composable Kernel: epilogue_cshuffle_v3_wmma.hpp Source File
epilogue_cshuffle_v3_wmma.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
// EpilogueCShuffle: block-level C-shuffle epilogue used with the WMMA blockwise GEMM
// pipeline (BlockwiseGemmPipe). It stages the per-thread accumulator tile through LDS
// chunk by chunk (chunks enumerated by Base::SpaceFillingCurveVgpr), then, for each
// chunk, reads C back from LDS together with the Ds tensors from global memory, applies
// CDEElementwiseOperation and stores E to global memory (global-side chunks enumerated by
// Base::SpaceFillingCurveVmem).
// All descriptor / space-filling-curve helpers come from EpilogueCShuffleBase, which is
// instantiated with exactly the same template arguments.
// NOTE(review): the embedded line numbers skip several values (e.g. 27, 46, 62, 65-67, 90,
// 136, 157, 168), so lines appear to be missing from this copy (including the `struct`
// declaration line and the `using Base = ...` line) - confirm against the upstream header.
10template <typename DsDataType,
11 typename EDataType,
12 typename AccDataType,
13 typename CShuffleDataType,
14 index_t MPerBlock,
15 index_t NPerBlock,
16 index_t MPerWmma,
17 index_t NPerWmma,
18 index_t MRepeat,
19 index_t NRepeat,
20 index_t CShuffleMRepeatPerShuffle,
21 index_t CShuffleNRepeatPerShuffle,
22 typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
23 typename CDEShuffleBlockTransferScalarPerVectors,
24 typename CDEElementwiseOperation,
25 typename ThisThreadBlock,
26 typename BlockwiseGemmPipe>
28 : EpilogueCShuffleBase<DsDataType,
29 EDataType,
30 AccDataType,
31 CShuffleDataType,
32 MPerBlock,
33 NPerBlock,
34 MPerWmma,
35 NPerWmma,
36 MRepeat,
37 NRepeat,
38 CShuffleMRepeatPerShuffle,
39 CShuffleNRepeatPerShuffle,
40 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
41 CDEShuffleBlockTransferScalarPerVectors,
42 CDEElementwiseOperation,
43 ThisThreadBlock,
44 BlockwiseGemmPipe>
45{
    // Alias for the base class, re-using every template argument of this struct.
    // NOTE(review): the `using Base = EpilogueCShuffleBase<` line (46) and the
    // ThisThreadBlock argument line (62) are missing from this copy.
47 DsDataType,
48 EDataType,
49 AccDataType,
50 CShuffleDataType,
51 MPerBlock,
52 NPerBlock,
53 MPerWmma,
54 NPerWmma,
55 MRepeat,
56 NRepeat,
57 CShuffleMRepeatPerShuffle,
58 CShuffleNRepeatPerShuffle,
59 CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 CDEShuffleBlockTransferScalarPerVectors,
61 CDEElementwiseOperation,
63 BlockwiseGemmPipe>;
64
    // Base-class members used unqualified below. Presumably lines 65-67 (missing here)
    // import the descriptor helpers such as GetVgprToLDSEpilogueDescriptor - TODO confirm.
68 using Base::I1;
69 using Base::NumDTensor;
70
    // Run: performs the complete epilogue for one output block.
    //   c_thread_buf    - this thread's accumulator data, laid out as described by
    //                     BlockwiseGemmPipe::GetCThreadDescriptor_...(); copied into LDS.
    //   p_ds_grid       - one global pointer per D tensor, indexed by i < NumDTensor.
    //   p_e_grid        - global pointer to the output tensor E.
    //   p_shared        - LDS scratch space, reinterpreted as CShuffleDataType*; presumably
    //                     must hold the shuffle block descriptor's element space - confirm.
    //   ds_grid_desc_mblock_mperblock_nblock_nperblock - per-D global descriptors.
    //   e_grid_desc_mblock_mperblock_nblock_nperblock  - global descriptor of E.
    //   cde_element_op  - elementwise op combining C (from LDS) with Ds to produce E.
    //   block_m_id, block_n_id - forwarded to the LDS-to-global copy object (presumably the
    //                     block's tile coordinates - TODO confirm).
    // EGlobalMemoryDataOperation is not referenced in the visible lines; presumably it is
    // passed to the LDS-to-global copy factory on the missing line 136 - confirm.
71 template <InMemoryDataOperationEnum EGlobalMemoryDataOperation,
72 typename CThreadBuf,
73 typename DsGridPointer,
74 typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
75 typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
76 __device__ static void Run(CThreadBuf& c_thread_buf,
77 DsGridPointer p_ds_grid,
78 EDataType* p_e_grid,
79 void* p_shared,
80 const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
81 ds_grid_desc_mblock_mperblock_nblock_nperblock,
82 const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
83 e_grid_desc_mblock_mperblock_nblock_nperblock,
84 CDEElementwiseOperation& cde_element_op,
85 const index_t& block_m_id,
86 const index_t& block_n_id)
87 {
        // One buffer per D tensor over p_ds_grid[i], sized by that descriptor's
        // GetElementSpaceSize(); the constructor line (90) is missing from this copy.
88 const auto ds_grid_buf = generate_tuple(
89 [&](auto i) {
91 p_ds_grid[i],
92 ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
93 },
95
        // Buffer over the output E, sized by its descriptor (declaration line 96 missing).
97 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
98
99 // C mapping in single thread.
100 constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
101 BlockwiseGemmPipe::
102 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
103
104 // LDS buffer
105 constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
107
        // The LDS shuffle buffer lives in p_shared, typed as CShuffleDataType.
108 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
109 static_cast<CShuffleDataType*>(p_shared),
110 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
111 .GetElementSpaceSize());
112
113 // Thread transfer Vgpr to LDS
114 auto c_thread_copy_vgpr_to_lds = GetVgprToLDSEpilogueDescriptor();
115
116 // Space Filling Curve Vgpr
117 constexpr auto sfc_c_vgpr = typename Base::SpaceFillingCurveVgpr{};
118
119 // Space Filling Curve Vmem
120 constexpr auto sfc_cde_global = typename Base::SpaceFillingCurveVmem{};
121
122 // Block descriptor
123 constexpr auto
124 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
126
        // Slot 0 is the LDS C-shuffle descriptor; slots 1..NumDTensor are the Ds
        // descriptors. This ordering is why the Ds windows are moved with index i + I1.
127 // tuple of reference to C/Ds tensor descriptors
128 const auto c_ds_desc_refs = concat_tuple_of_reference(
129 tie(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
130 generate_tie([&](auto i) -> const auto& // return type should be reference
131 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
133
        // Copy object that reads C (LDS) and Ds (global), applies cde_element_op and writes E.
        // NOTE(review): its factory call (line 136) is missing here; presumably it is
        // Base::GetLDSToVmemEpilogueDescriptor - confirm.
134 // Thread transfer LDS to Vmem
135 auto cde_shuffle_block_copy_lds_and_global =
137 c_ds_desc_refs,
138 e_grid_desc_mblock_mperblock_nblock_nperblock,
139 cde_element_op,
140 block_m_id,
141 block_n_id);
142
        // Buffers in the same slot order as c_ds_desc_refs: LDS C first, then the Ds.
143 // tuple of reference to C/Ds tensor buffers
144 const auto c_ds_buf_refs = concat_tuple_of_reference(
145 tie(c_shuffle_block_buf),
146 generate_tie([&](auto i) -> const auto& // return type should be reference
147 { return ds_grid_buf[i]; },
149
150 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
151
        // Both curves must visit the same number of chunks so that every VGPR->LDS write
        // is paired with exactly one LDS->global store in the loop below.
152 static_assert(num_access == sfc_cde_global.GetNumOfAccess(), "wrong!");
153
154 // CShuffle and Store
155 static_for<0, num_access, 1>{}([&](auto access_id) {
        // NOTE(review): the barrier statements expected after the two "make sure it's safe"
        // comments below (lines 157 and 168, presumably block_sync_lds()) are missing here.
156 // make sure it's safe to write to LDS
158
159 // each thread write its data from VGPR to LDS
160 c_thread_copy_vgpr_to_lds.Run(
161 c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
162 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
163 c_thread_buf,
164 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
165 c_shuffle_block_buf);
166
167 // make sure it's safe to read from LDS
169
170 // each block loads its C data from LDS, D from global, applies elementwise
171 // operation and stores result E to global
172 cde_shuffle_block_copy_lds_and_global.Run(
173 c_ds_desc_refs,
174 c_ds_buf_refs,
175 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
176 tie(e_grid_buf));
177
            // Advance the global windows to the next chunk; skipped after the last access
            // because no further store follows it.
178 if constexpr(access_id < num_access - 1)
179 {
180 constexpr auto cde_global_step = sfc_cde_global.GetForwardStep(access_id);
181 // move on Ds
            // i + I1 skips slot 0 (the LDS C descriptor) of c_ds_desc_refs.
182 static_for<0, NumDTensor, 1>{}([&](auto i) {
183 cde_shuffle_block_copy_lds_and_global.MoveSrcSliceWindow(
184 c_ds_desc_refs, i + I1, cde_global_step);
185 });
186
187 // move on E
188 cde_shuffle_block_copy_lds_and_global.MoveDstSliceWindow(
189 tie(e_grid_desc_mblock_mperblock_nblock_nperblock), cde_global_step);
190 }
191 });
192 }
193};
194
195} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition utility/tuple.hpp:218
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
__host__ __device__ constexpr auto generate_tie(F &&f, Number< N >)
Definition tuple_helper.hpp:34
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition tuple_helper.hpp:42
Definition epilogue_cshuffle_v3_wmma_base.hpp:29
static constexpr index_t NumDTensor
Definition epilogue_cshuffle_v3_wmma_base.hpp:38
static __device__ auto GetLDSToVmemEpilogueDescriptor(CDsDescRefs &c_ds_desc_refs, EGridDesc &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_wmma_base.hpp:204
SpaceFillingCurve< Sequence< MRepeat, 1, 1, NRepeat, 1, 1, BlockwiseGemmPipe::MAccVgprs >, Sequence< 0, 1, 2, 3, 4, 5, 6 >, Sequence< CShuffleMRepeatPerShuffle, 1, 1, CShuffleNRepeatPerShuffle, 1, 1, BlockwiseGemmPipe::MAccVgprs > > SpaceFillingCurveVgpr
Definition epilogue_cshuffle_v3_wmma_base.hpp:42
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
SpaceFillingCurve< Sequence< 1, MPerBlock, 1, NPerBlock >, Sequence< 0, 2, 1, 3 >, Sequence< 1, CShuffleMRepeatPerShuffle *BlockwiseGemmPipe::MWaves *MPerWmma, 1, CShuffleNRepeatPerShuffle *BlockwiseGemmPipe::NWaves *NPerWmma > > SpaceFillingCurveVmem
Definition epilogue_cshuffle_v3_wmma_base.hpp:53
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
Definition epilogue_cshuffle_v3_wmma.hpp:45
static __device__ constexpr auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition epilogue_cshuffle_v3_wmma_base.hpp:63
static constexpr auto I1
Definition epilogue_cshuffle_v3_wmma_base.hpp:31
static __device__ auto GetVgprToLDSEpilogueDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:118
EpilogueCShuffleBase< DsDataType, EDataType, AccDataType, CShuffleDataType, MPerBlock, NPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, CDEElementwiseOperation, ThisThreadBlock, BlockwiseGemmPipe > Base
Definition epilogue_cshuffle_v3_wmma.hpp:46
static __device__ constexpr auto GetCShuffleLDSDescriptor()
Definition epilogue_cshuffle_v3_wmma_base.hpp:78
static __device__ void Run(CThreadBuf &c_thread_buf, DsGridPointer p_ds_grid, EDataType *p_e_grid, void *p_shared, const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, CDEElementwiseOperation &cde_element_op, const index_t &block_m_id, const index_t &block_n_id)
Definition epilogue_cshuffle_v3_wmma.hpp:76
Definition thread_group.hpp:12
Definition functional2.hpp:33