device_gemm_multiple_d.hpp Source File

device_gemm_multiple_d.hpp Source File

Composable Kernel: device_gemm_multiple_d.hpp Source File
device_gemm_multiple_d.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5#ifndef __HIPCC_RTC__
6#include <array>
7#endif
8
11
12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16// GEMM:
17// input : A[M, K], B[K, N],
18// input : D0[M, N], D1[M, N], ...
19// output : E[M, N]
20// C = a_op(A) * b_op(B)
21// E = cde_op(C, D0, D1, ...)
22// Assume:
23// D0, D1, ... and E have the same layout
// Abstract interface for the GEMM described in the comment above: it only declares
// two pure-virtual factories, one for an argument object and one for an invoker.
// NOTE(review): original source line 35 is absent from this listing; it presumably
// carries the struct name and base-class list -- confirm against the upstream header.
24template <typename ALayout,
25 typename BLayout,
26 typename DsLayout,
27 typename ELayout,
28 typename ADataType,
29 typename BDataType,
30 typename DsDataType,
31 typename EDataType,
32 typename AElementwiseOperation,
33 typename BElementwiseOperation,
34 typename CDEElementwiseOperation>
36{
// Compile-time count taken from DsDataType::Size(); it is the length of both the
// p_ds and StrideDs arrays below.
37 static constexpr index_t NumDTensor = DsDataType::Size();
38
// The virtual factories are compiled only when __HIPCC_RTC__ is not defined.
39#ifndef __HIPCC_RTC__
// Builds the argument object for one problem instance.
//   p_a, p_b, p_ds[i] : const (read-only) pointers to the A, B and D inputs
//   p_e               : the only non-const pointer, i.e. the output E
//   StrideA/B/Ds/E    : one stride per tensor (StrideDs has NumDTensor entries)
//   *_element_op      : element-wise operators, passed by value
// NOTE(review): original lines 45-47 are absent from this listing; the signature
// index at the end of this page shows ck::index_t M, N, K between p_e and StrideA.
40 virtual std::unique_ptr<BaseArgument>
41 MakeArgumentPointer(const void* p_a,
42 const void* p_b,
43 std::array<const void*, NumDTensor> p_ds,
44 void* p_e,
48 ck::index_t StrideA,
49 ck::index_t StrideB,
50 std::array<ck::index_t, NumDTensor> StrideDs,
51 ck::index_t StrideE,
52 AElementwiseOperation a_element_op,
53 BElementwiseOperation b_element_op,
54 CDEElementwiseOperation cde_element_op) = 0;
55
// Builds the invoker object; takes no arguments and is pure virtual.
56 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57#endif
58};
59
60// GEMM:
61// input : A[M, K], B[K, N],
62// input : D0[M, N], D1[M, N], ...
63// output : E[M, N]
64// C = a_op(A) * b_op(B)
65// E = cde_op(C, D0, D1, ...)
66// Assume:
67// D0, D1, ... and E have the same layout
// Abstract interface with the same template parameters as the preceding struct;
// the visible difference is the extra ck::index_t KBatch parameter on
// MakeArgumentPointer.
// NOTE(review): original source line 79 is absent from this listing; it presumably
// carries the struct name and base-class list -- confirm against the upstream header.
68template <typename ALayout,
69 typename BLayout,
70 typename DsLayout,
71 typename ELayout,
72 typename ADataType,
73 typename BDataType,
74 typename DsDataType,
75 typename EDataType,
76 typename AElementwiseOperation,
77 typename BElementwiseOperation,
78 typename CDEElementwiseOperation>
80{
// Compile-time count taken from DsDataType::Size(); it is the length of both the
// p_ds and StrideDs arrays below.
81 static constexpr index_t NumDTensor = DsDataType::Size();
82
// The virtual factories are compiled only when __HIPCC_RTC__ is not defined.
83#ifndef __HIPCC_RTC__
// Builds the argument object for one problem instance.
//   p_a, p_b, p_ds[i] : const (read-only) pointers to the A, B and D inputs
//   p_e               : the only non-const pointer, i.e. the output E
//   StrideA/B/Ds/E    : one stride per tensor (StrideDs has NumDTensor entries)
//   KBatch            : integer placed between StrideE and the element-wise ops;
//                       presumably the split-K factor -- TODO confirm
//   *_element_op      : element-wise operators, passed by value
// NOTE(review): original lines 89-91 are absent from this listing; the signature
// index at the end of this page shows ck::index_t M, N, K between p_e and StrideA.
84 virtual std::unique_ptr<BaseArgument>
85 MakeArgumentPointer(const void* p_a,
86 const void* p_b,
87 std::array<const void*, NumDTensor> p_ds,
88 void* p_e,
92 ck::index_t StrideA,
93 ck::index_t StrideB,
94 std::array<ck::index_t, NumDTensor> StrideDs,
95 ck::index_t StrideE,
96 ck::index_t KBatch,
97 AElementwiseOperation a_element_op,
98 BElementwiseOperation b_element_op,
99 CDEElementwiseOperation cde_element_op) = 0;
100
// Builds the invoker object; takes no arguments and is pure virtual.
101 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
102#endif
103};
104
105// GEMM:
106// input : A[M, K], B[K, N],
107// input : D0[M, N], D1[M, N], ...
108// output : E[M, N]
109// C = a_op(A) * b_op(B)
110// E = cde_op(C, D0, D1, ...)
111// Assume:
112// D0, D1, ... and E have the same layout
// Abstract interface with the same MakeArgumentPointer signature as the KBatch
// variant above, plus one extra pure-virtual query, GetPreShuffleParameters().
// NOTE(review): original source line 124 is absent from this listing; it presumably
// carries the struct name and base-class list -- confirm against the upstream header.
113template <typename ALayout,
114 typename BLayout,
115 typename DsLayout,
116 typename ELayout,
117 typename ADataType,
118 typename BDataType,
119 typename DsDataType,
120 typename EDataType,
121 typename AElementwiseOperation,
122 typename BElementwiseOperation,
123 typename CDEElementwiseOperation>
125{
// Compile-time count taken from DsDataType::Size(); it is the length of both the
// p_ds and StrideDs arrays below.
126 static constexpr index_t NumDTensor = DsDataType::Size();
127
// The virtual members are compiled only when CK_CODE_GEN_RTC is not defined.
// NOTE(review): the two structs above and the wrapper at the end of this file
// guard on __HIPCC_RTC__ instead -- confirm that the different macro is intentional.
128#ifndef CK_CODE_GEN_RTC
// Builds the argument object for one problem instance.
//   p_a, p_b, p_ds[i] : const (read-only) pointers to the A, B and D inputs
//   p_e               : the only non-const pointer, i.e. the output E
//   M, N, K           : problem sizes (A[M, K], B[K, N], E[M, N] per the comment above)
//   StrideA/B/Ds/E    : one stride per tensor (StrideDs has NumDTensor entries)
//   KBatch            : integer placed between StrideE and the element-wise ops;
//                       presumably the split-K factor -- TODO confirm
//   *_element_op      : element-wise operators, passed by value
129 virtual std::unique_ptr<BaseArgument>
130 MakeArgumentPointer(const void* p_a,
131 const void* p_b,
132 std::array<const void*, NumDTensor> p_ds,
133 void* p_e,
134 ck::index_t M,
135 ck::index_t N,
136 ck::index_t K,
137 ck::index_t StrideA,
138 ck::index_t StrideB,
139 std::array<ck::index_t, NumDTensor> StrideDs,
140 ck::index_t StrideE,
141 ck::index_t KBatch,
142 AElementwiseOperation a_element_op,
143 BElementwiseOperation b_element_op,
144 CDEElementwiseOperation cde_element_op) = 0;
145
// Builds the invoker object; takes no arguments and is pure virtual.
146 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
147
// Returns an int chosen by the implementation; its meaning is not defined in
// this file.
148 virtual int GetPreShuffleParameters() = 0;
149#endif
150};
151
// Abstract interface that extends the KBatch/pre-shuffle signature with separate
// scale inputs for A and B: two extra data types, two extra pointers and two
// extra strides. ScaleBlockSize is a compile-time integer that is not referenced
// inside the visible body.
// NOTE(review): original source line 166 is absent from this listing; it presumably
// carries the struct name and base-class list -- confirm against the upstream header.
152template <typename ALayout,
153 typename BLayout,
154 typename DsLayout,
155 typename ELayout,
156 typename ADataType,
157 typename AScaleDataType,
158 typename BDataType,
159 typename BScaleDataType,
160 typename DsDataType,
161 typename EDataType,
162 index_t ScaleBlockSize,
163 typename AElementwiseOperation,
164 typename BElementwiseOperation,
165 typename CDEElementwiseOperation>
167{
// Compile-time count taken from DsDataType::Size(); it is the length of both the
// p_ds and StrideDs arrays below.
168 static constexpr index_t NumDTensor = DsDataType::Size();
169
// The virtual members are compiled only when CK_CODE_GEN_RTC is not defined.
170#ifndef CK_CODE_GEN_RTC
// Builds the argument object for one problem instance.
//   p_a, p_b             : const pointers to the A and B inputs
//   p_a_scale, p_b_scale : const pointers to the scale data for A and B
//   p_ds[i]              : const pointers to the D inputs
//   p_e                  : the only non-const pointer, i.e. the output
//   M, N, K              : problem sizes
//   StrideA/B/Ds/E       : one stride per tensor (StrideDs has NumDTensor entries)
//   StrideAScale/BScale  : strides for the two scale inputs
//   KBatch               : integer placed between StrideE and the element-wise ops;
//                          presumably the split-K factor -- TODO confirm
//   *_element_op         : element-wise operators, passed by value
171 virtual std::unique_ptr<BaseArgument>
172 MakeArgumentPointer(const void* p_a,
173 const void* p_a_scale,
174 const void* p_b,
175 const void* p_b_scale,
176 std::array<const void*, NumDTensor> p_ds,
177 void* p_e,
178 ck::index_t M,
179 ck::index_t N,
180 ck::index_t K,
181 ck::index_t StrideA,
182 ck::index_t StrideAScale,
183 ck::index_t StrideB,
184 ck::index_t StrideBScale,
185 std::array<ck::index_t, NumDTensor> StrideDs,
186 ck::index_t StrideE,
187 ck::index_t KBatch,
188 AElementwiseOperation a_element_op,
189 BElementwiseOperation b_element_op,
190 CDEElementwiseOperation cde_element_op) = 0;
191
// Builds the invoker object; takes no arguments and is pure virtual.
192 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
193
// Returns an int chosen by the implementation; its meaning is not defined in
// this file.
194 virtual int GetPreShuffleParameters() = 0;
195#endif
196};
197
// Adapter that owns a DeviceOp and forwards every call to it. Its
// MakeArgumentPointer takes no KBatch parameter and always passes 1 to the
// wrapped operation.
// NOTE(review): original source line 216 is absent from this listing; it presumably
// declares `struct DeviceGemmMultipleDSplitKWrapper` and begins the base-class
// template argument list continued on lines 217-226 -- confirm against upstream.
205template <typename ALayout,
206 typename BLayout,
207 typename DsLayout,
208 typename ELayout,
209 typename ADataType,
210 typename BDataType,
211 typename DsDataType,
212 typename EDataType,
213 typename AElementwiseOperation,
214 typename BElementwiseOperation,
215 typename CDEElementwiseOperation>
217 BLayout,
218 DsLayout,
219 ELayout,
220 ADataType,
221 BDataType,
222 DsDataType,
223 EDataType,
224 AElementwiseOperation,
225 BElementwiseOperation,
226 CDEElementwiseOperation>
227{
// NOTE(review): original line 228 is absent from this listing; the index at the
// end of this page gives it as the alias
// DeviceOp = DeviceGemmMultipleDSplitK<ALayout, ...>, whose remaining template
// arguments follow on lines 229-238.
229 BLayout,
230 DsLayout,
231 ELayout,
232 ADataType,
233 BDataType,
234 DsDataType,
235 EDataType,
236 AElementwiseOperation,
237 BElementwiseOperation,
238 CDEElementwiseOperation>;
239
// Compile-time count taken from DsDataType::Size(); it is the length of both the
// p_ds and StrideDs arrays below.
240 static constexpr index_t NumDTensor = DsDataType::Size();
241
// All members below are compiled only when __HIPCC_RTC__ is not defined.
242#ifndef __HIPCC_RTC__
243
// Takes ownership of the wrapped operation. No null check is performed here, and
// every forwarding member dereferences p_op_, so callers must pass a non-null
// pointer.
244 explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
245 : p_op_(std::move(p_op))
246 {
247 }
248
// Forwards the support query to the wrapped operation unchanged.
249 bool IsSupportedArgument(const BaseArgument* p_arg) override
250 {
251 return p_op_->IsSupportedArgument(p_arg);
252 }
// Forwards every parameter to the wrapped operation in the same order, inserting
// a literal 1 in the KBatch position between StrideE and a_element_op.
253 std::unique_ptr<BaseArgument>
254 MakeArgumentPointer(const void* p_a,
255 const void* p_b,
256 std::array<const void*, NumDTensor> p_ds,
257 void* p_e,
258 ck::index_t M,
259 ck::index_t N,
260 ck::index_t K,
261 ck::index_t StrideA,
262 ck::index_t StrideB,
263 std::array<ck::index_t, NumDTensor> StrideDs,
264 ck::index_t StrideE,
265 AElementwiseOperation a_element_op,
266 BElementwiseOperation b_element_op,
267 CDEElementwiseOperation cde_element_op) override
268 {
269 return p_op_->MakeArgumentPointer(p_a,
270 p_b,
271 p_ds,
272 p_e,
273 M,
274 N,
275 K,
276 StrideA,
277 StrideB,
278 StrideDs,
279 StrideE,
280 1, // KBatch
281 a_element_op,
282 b_element_op,
283 cde_element_op);
284 }
285
// Returns the invoker created by the wrapped operation.
286 std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
287 {
288 return p_op_->MakeInvokerPointer();
289 }
290
// Reports the wrapped operation's type string rather than a wrapper-specific one.
291 std::string GetTypeString() const override { return p_op_->GetTypeString(); }
292
293 private:
// Owned wrapped operation; set once in the constructor and never reassigned in
// this struct.
294 std::unique_ptr<DeviceOp> p_op_;
295
296#endif // __HIPCC_RTC__
297};
298
299} // namespace device
300} // namespace tensor_operation
301} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
STL namespace.
Definition device_base.hpp:197
Definition device_gemm_multiple_d.hpp:36
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d.hpp:37
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d.hpp:126
Definition device_gemm_multiple_d.hpp:80
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d.hpp:81
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition device_gemm_multiple_d.hpp:249
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition device_gemm_multiple_d.hpp:254
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition device_gemm_multiple_d.hpp:286
DeviceGemmMultipleDSplitKWrapper(std::unique_ptr< DeviceOp > p_op)
Definition device_gemm_multiple_d.hpp:244
std::string GetTypeString() const override
Definition device_gemm_multiple_d.hpp:291
DeviceGemmMultipleDSplitK< ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation > DeviceOp
Definition device_gemm_multiple_d.hpp:228
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d.hpp:240
Definition device_gemm_multiple_d.hpp:167
static constexpr index_t NumDTensor
Definition device_gemm_multiple_d.hpp:168
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_a_scale, const void *p_b, const void *p_b_scale, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideAScale, ck::index_t StrideB, ck::index_t StrideBScale, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0