threadwise_tensor_slice_transfer_v6r3.hpp Source File

threadwise_tensor_slice_transfer_v6r3.hpp Source File#

Composable Kernel: threadwise_tensor_slice_transfer_v6r3.hpp Source File
threadwise_tensor_slice_transfer_v6r3.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
11namespace ck {
12
13// Do the following to avoid "alloca" in LLVM-IR, which would cause scratch-memory usage
14// and sometimes useless instructions:
15// 1. Don't save a reference to the tensor descriptor in the class; pass the tensor descriptor
16// in as an argument instead.
17// 2. Don't construct a new tensor coordinate every time it is used; update and reuse the same
18// tensor coordinate instead.
19// 3. Don't use a pointer to a VGPR buffer; use a vector instead.
20
21// Assume:
22// 1. src0_desc, src1_desc, src2_desc and dst_desc are not known at compile time.
23// 2. Src0Buffer, Src1Buffer, Src2Buffer and DstBuffer are DynamicBuffer.
24// 3. The src0/src1/src2/dst slice origins are not known at compile time.
// ThreadwiseTensorSliceTransfer_v6r3
//
// Per-thread, three-input elementwise transfer. Run() walks the slice described by
// SliceLengths along a space-filling curve (dimension order DimAccessOrder, vector accesses
// of ScalarPerVector elements along VectorDim). At every access it loads one vector from
// each of src0/src1/src2 and calls
//   element_op_(dst[i], src0[i], src1[i], src2[i])
// for each element i of the vector. It then writes the resulting dst vector with
// dst_buf.Update<DstInMemOp, ...>.
//
// Template parameters:
//   Src0Data/Src1Data/Src2Data/DstData - element types of the three sources and the destination.
//   Src0Desc/Src1Desc/Src2Desc/DstDesc - tensor descriptor types. Descriptors are not stored;
//                                         they are passed to every member function that needs them.
//   ElementwiseOperation               - functor invoked as op(dst&, src0, src1, src2). It is
//                                         called through a const data member, so its operator()
//                                         must be const-callable.
//   SliceLengths                       - per-dimension lengths of the slice to transfer.
//   DimAccessOrder                     - dimension order used by the space-filling curve.
//   VectorDim, ScalarPerVector         - vectorized dimension and vector width. The slice length
//                                         along VectorDim must be divisible by ScalarPerVector.
//   *ResetCoordinateAfterRun           - if true, Run() moves that coordinate back to the slice
//                                         origin when it finishes. If false, the coordinate stays at
//                                         the last access, and the matching Move*SliceWindow() adds
//                                         GetCoordinateResetStep() to compensate.
25template <typename Src0Data,
26 typename Src1Data,
27 typename Src2Data,
28 typename DstData,
29 typename Src0Desc,
30 typename Src1Desc,
31 typename Src2Desc,
32 typename DstDesc,
33 typename ElementwiseOperation,
34 typename SliceLengths,
35 typename DimAccessOrder,
36 index_t VectorDim,
37 index_t ScalarPerVector,
39 bool Src0ResetCoordinateAfterRun,
40 bool Src1ResetCoordinateAfterRun,
41 bool Src2ResetCoordinateAfterRun,
42 bool DstResetCoordinateAfterRun>
44{
45 static constexpr index_t nDim = SliceLengths::Size();
46
48
 // Coordinate types, deduced from what make_tensor_coordinate returns for each descriptor type.
49 using Src0Coord = decltype(make_tensor_coordinate(Src0Desc{}, Index{}));
50 using Src1Coord = decltype(make_tensor_coordinate(Src1Desc{}, Index{}));
51 using Src2Coord = decltype(make_tensor_coordinate(Src2Desc{}, Index{}));
52 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
53
54 static constexpr auto I0 = Number<0>{};
55
 // Creates the four tensor coordinates at the given slice origins and copies element_op.
 // The static_assert rejects, at compile time, any slice whose length along VectorDim is
 // not a multiple of ScalarPerVector.
56 __device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc& src0_desc,
57 const Index& src0_slice_origin,
58 const Src1Desc& src1_desc,
59 const Index& src1_slice_origin,
60 const Src2Desc& src2_desc,
61 const Index& src2_slice_origin,
62 const DstDesc& dst_desc,
63 const Index& dst_slice_origin,
64 const ElementwiseOperation& element_op)
65 : src0_coord_(make_tensor_coordinate(src0_desc, src0_slice_origin)),
66 src1_coord_(make_tensor_coordinate(src1_desc, src1_slice_origin)),
67 src2_coord_(make_tensor_coordinate(src2_desc, src2_slice_origin)),
68 dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin)),
69 element_op_(element_op)
70 {
71 static_assert(SliceLengths::At(Number<VectorDim>{}) % ScalarPerVector == 0,
72 "wrong! cannot evenly divide");
73 }
74
 // Set*SliceOrigin: rebuild the coordinate at an absolute origin. The previous position is
 // discarded rather than offset.
75 __device__ void SetSrc0SliceOrigin(const Src0Desc& src0_desc,
76 const Index& src0_slice_origin_idx)
77 {
78 src0_coord_ = make_tensor_coordinate(src0_desc, src0_slice_origin_idx);
79 }
80
81 __device__ void SetSrc1SliceOrigin(const Src1Desc& src1_desc,
82 const Index& src1_slice_origin_idx)
83 {
84 src1_coord_ = make_tensor_coordinate(src1_desc, src1_slice_origin_idx);
85 }
86
87 __device__ void SetSrc2SliceOrigin(const Src2Desc& src2_desc,
88 const Index& src2_slice_origin_idx)
89 {
90 src2_coord_ = make_tensor_coordinate(src2_desc, src2_slice_origin_idx);
91 }
92
93 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
94 {
95 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
96 }
97
 // Transfers the whole slice. For each access on the space-filling curve:
 //   1. read one ScalarPerVector-wide vector from each source at its current coordinate,
 //      forwarding the is_src*_valid flag to Get() (what an invalid read yields is decided
 //      by the buffer type);
 //   2. apply element_op_ element by element into a value-initialized dst vector;
 //   3. write that vector with dst_buf.Update<DstInMemOp, ...>, forwarding is_dst_valid;
 //   4. advance all four coordinates by the curve's forward step (skipped after the last
 //      access).
 // Afterwards, each coordinate whose *ResetCoordinateAfterRun flag is true is moved back to
 // the slice origin.
98 template <typename Src0Buffer, typename Src1Buffer, typename Src2Buffer, typename DstBuffer>
99 __device__ void Run(const Src0Desc& src0_desc,
100 const Src0Buffer& src0_buf,
101 const Src1Desc& src1_desc,
102 const Src1Buffer& src1_buf,
103 const Src2Desc& src2_desc,
104 const Src2Buffer& src2_buf,
105 const DstDesc& dst_desc,
106 DstBuffer& dst_buf)
107 {
108 // scalar per access on each dim
109 // TODO: don't use lambda_scalar_per_access
110 constexpr auto scalar_per_access = generate_sequence(
112
113 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
114 DimAccessOrder,
115 remove_cv_t<decltype(scalar_per_access)>>;
116
117 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
118
119 // loop over space-filling curve
120 static_for<0, num_access, 1>{}([&](auto idx_1d) {
 // One vector type per operand, each holding ScalarPerVector elements.
121 using src0_vector_type = vector_type_maker_t<Src0Data, ScalarPerVector>;
122 using src0_vector_t = typename src0_vector_type::type;
123
124 using src1_vector_type = vector_type_maker_t<Src1Data, ScalarPerVector>;
125 using src1_vector_t = typename src1_vector_type::type;
126
127 using src2_vector_type = vector_type_maker_t<Src2Data, ScalarPerVector>;
128 using src2_vector_t = typename src2_vector_type::type;
129
130 using dst_vector_type = vector_type_maker_t<DstData, ScalarPerVector>;
131 using dst_vector_t = typename dst_vector_type::type;
132
 // Per-source validity flags for the current coordinate. Presumably computed with
 // coordinate_has_valid_offset_assuming_visible_index_is_valid -- confirm.
133 const bool is_src0_valid =
135
136 const bool is_src1_valid =
138
139 const bool is_src2_valid =
141
142 // copy data from src0_buf into src0_vector_container
143 auto src0_vector_container = src0_vector_type{
144 src0_buf.template Get<src0_vector_t>(src0_coord_.GetOffset(), is_src0_valid)};
145
 // copy data from src1_buf / src2_buf the same way
146 auto src1_vector_container = src1_vector_type{
147 src1_buf.template Get<src1_vector_t>(src1_coord_.GetOffset(), is_src1_valid)};
148
149 auto src2_vector_container = src2_vector_type{
150 src2_buf.template Get<src2_vector_t>(src2_coord_.GetOffset(), is_src2_valid)};
151
152 auto dst_vector_container = dst_vector_type{};
153
154 // apply pointwise operation: dst[i] = op(src0[i], src1[i], src2[i]) via the out-parameter
156 element_op_(dst_vector_container.template AsType<DstData>()(i),
157 src0_vector_container.template AsType<Src0Data>()[i],
158 src1_vector_container.template AsType<Src1Data>()[i],
159 src2_vector_container.template AsType<Src2Data>()[i]);
160 });
161
 // Validity flag for the destination write at the current coordinate.
162 const bool is_dst_valid =
164
 // Write the whole dst vector at once; the in-memory operation is selected by DstInMemOp.
165 dst_buf.template Update<DstInMemOp, dst_vector_t>(
166 dst_coord_.GetOffset(),
167 is_dst_valid,
168 dst_vector_container.template AsType<dst_vector_t>()[I0]);
169
170 // move coordinate
 // The step is skipped after the last access, so the coordinates end the loop at the
 // final access position (see GetCoordinateResetStep).
171 if constexpr(idx_1d.value != num_access - 1)
172 {
173 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
175 src0_desc, src0_coord_, make_tensor_coordinate_step(src0_desc, forward_step));
177 src1_desc, src1_coord_, make_tensor_coordinate_step(src1_desc, forward_step));
179 src2_desc, src2_coord_, make_tensor_coordinate_step(src2_desc, forward_step));
181 dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
182 }
183 });
184
185 // move coordinate back to slice origin (or not)
 // Only coordinates whose *ResetCoordinateAfterRun flag is true are moved. The others stay
 // at the last access, and their Move*SliceWindow() folds in the reset step instead.
186 if constexpr(Src0ResetCoordinateAfterRun)
187 {
188 const auto src0_reset_step =
190
191 move_tensor_coordinate(src0_desc, src0_coord_, src0_reset_step);
192 }
193
194 if constexpr(Src1ResetCoordinateAfterRun)
195 {
196 const auto src1_reset_step =
198
199 move_tensor_coordinate(src1_desc, src1_coord_, src1_reset_step);
200 }
201
202 if constexpr(Src2ResetCoordinateAfterRun)
203 {
204 const auto src2_reset_step =
206
207 move_tensor_coordinate(src2_desc, src2_coord_, src2_reset_step);
208 }
209
210 if constexpr(DstResetCoordinateAfterRun)
211 {
212 const auto dst_reset_step =
214
215 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
216 }
217 }
218
 // Returns the step, in visible-index space, that brings a coordinate from where Run() leaves
 // it (the last access on the space-filling curve) back to the slice origin. The same curve
 // parameters as Run() are used, so the two agree. When the curve has no accesses
 // (num_access == 0), a value-initialized (all-zero) Index is returned.
219 __device__ static constexpr auto GetCoordinateResetStep()
220 {
221 constexpr auto scalar_per_access = generate_sequence(
223
224 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
225 DimAccessOrder,
226 remove_cv_t<decltype(scalar_per_access)>>;
227
228 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
229 if constexpr(num_access == 0)
230 {
231 return typename SpaceFillingCurve::Index{};
232 }
233 else
234 {
235 constexpr auto reset_step =
237
238 return reset_step;
239 }
240 }
241
242 // src0_slice_origin_step_idx should be known at compile time, for performance reasons.
 // Moves the src0 window by src0_slice_origin_step_idx relative to the slice origin.
243 __device__ void MoveSrc0SliceWindow(const Src0Desc& src0_desc,
244 const Index& src0_slice_origin_step_idx)
245 {
246 // if the src0 coord was not reset by Run(), fold the reset step into this move
247 const auto adjusted_step_idx = Src0ResetCoordinateAfterRun
248 ? src0_slice_origin_step_idx
249 : src0_slice_origin_step_idx + GetCoordinateResetStep();
250
251 // is it OK to construct a new step every time?
252 const auto adjusted_step = make_tensor_coordinate_step(src0_desc, adjusted_step_idx);
253
254 move_tensor_coordinate(src0_desc, src0_coord_, adjusted_step);
255 }
256
257 // src1_slice_origin_step_idx should be known at compile time, for performance reasons.
 // Moves the src1 window by src1_slice_origin_step_idx relative to the slice origin.
258 __device__ void MoveSrc1SliceWindow(const Src1Desc& src1_desc,
259 const Index& src1_slice_origin_step_idx)
260 {
261 // if the src1 coord was not reset by Run(), fold the reset step into this move
262 const auto adjusted_step_idx = Src1ResetCoordinateAfterRun
263 ? src1_slice_origin_step_idx
264 : src1_slice_origin_step_idx + GetCoordinateResetStep();
265
266 // is it OK to construct a new step every time?
267 const auto adjusted_step = make_tensor_coordinate_step(src1_desc, adjusted_step_idx);
268
269 move_tensor_coordinate(src1_desc, src1_coord_, adjusted_step);
270 }
271
272 // src2_slice_origin_step_idx should be known at compile time, for performance reasons.
 // Moves the src2 window by src2_slice_origin_step_idx relative to the slice origin.
273 __device__ void MoveSrc2SliceWindow(const Src2Desc& src2_desc,
274 const Index& src2_slice_origin_step_idx)
275 {
276 // if the src2 coord was not reset by Run(), fold the reset step into this move
277 const auto adjusted_step_idx = Src2ResetCoordinateAfterRun
278 ? src2_slice_origin_step_idx
279 : src2_slice_origin_step_idx + GetCoordinateResetStep();
280
281 // is it OK to construct a new step every time?
282 const auto adjusted_step = make_tensor_coordinate_step(src2_desc, adjusted_step_idx);
283
284 move_tensor_coordinate(src2_desc, src2_coord_, adjusted_step);
285 }
286
287 // dst_slice_origin_step_idx should be known at compile time, for performance reasons.
 // Moves the dst window by dst_slice_origin_step_idx relative to the slice origin.
288 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
289 const Index& dst_slice_origin_step_idx)
290 {
291 // if the dst coord was not reset by Run(), fold the reset step into this move
292 const auto adjusted_step_idx = DstResetCoordinateAfterRun
293 ? dst_slice_origin_step_idx
294 : dst_slice_origin_step_idx + GetCoordinateResetStep();
295
296 // is it OK to construct a new step every time?
297 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
298
299 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
300 }
301
302 private:
 // Current coordinates of the three sources and the destination. They persist across Run()
 // calls and are repositioned by Set*SliceOrigin / Move*SliceWindow.
303 Src0Coord src0_coord_;
304 Src1Coord src1_coord_;
305 Src2Coord src2_coord_;
306 DstCoord dst_coord_;
 // Held as a const member: the implicitly-declared copy-assignment operator of this class is
 // therefore deleted, and element_op_ can only be invoked through const member functions.
307 const ElementwiseOperation element_op_;
308};
309
310} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
InMemoryDataOperationEnum
Definition ck.hpp:277
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
__device__ void MoveSrc0SliceWindow(const Src0Desc &src0_desc, const Index &src0_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:243
__device__ void MoveSrc1SliceWindow(const Src1Desc &src1_desc, const Index &src1_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:258
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:93
__device__ void SetSrc0SliceOrigin(const Src0Desc &src0_desc, const Index &src0_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:75
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:288
__device__ constexpr ThreadwiseTensorSliceTransfer_v6r3(const Src0Desc &src0_desc, const Index &src0_slice_origin, const Src1Desc &src1_desc, const Index &src1_slice_origin, const Src2Desc &src2_desc, const Index &src2_slice_origin, const DstDesc &dst_desc, const Index &dst_slice_origin, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:56
static __device__ constexpr auto GetCoordinateResetStep()
Definition threadwise_tensor_slice_transfer_v6r3.hpp:219
__device__ void MoveSrc2SliceWindow(const Src2Desc &src2_desc, const Index &src2_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:273
__device__ void SetSrc2SliceOrigin(const Src2Desc &src2_desc, const Index &src2_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:87
__device__ void SetSrc1SliceOrigin(const Src1Desc &src1_desc, const Index &src1_slice_origin_idx)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:81
__device__ void Run(const Src0Desc &src0_desc, const Src0Buffer &src0_buf, const Src1Desc &src1_desc, const Src1Buffer &src1_buf, const Src2Desc &src2_desc, const Src2Buffer &src2_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer_v6r3.hpp:99
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition functional2.hpp:33