threadwise_tensor_slice_transfer.hpp Source File

threadwise_tensor_slice_transfer.hpp Source File

Composable Kernel: threadwise_tensor_slice_transfer.hpp Source File
threadwise_tensor_slice_transfer.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
10
12
14
15namespace ck {
// NOTE(review): this listing was extracted from a rendered source page. Each line still carries
// its original line number as a prefix, and several original lines are missing (34, 38, 42, 75,
// 87, 90, 99-100, 109, 122, 134, 143, 152, 166). Restore them from the original header before
// compiling; the comments added here describe only what the visible lines do.
//
// ThreadwiseTensorSliceTransfer_v1r3 (name taken from the constructor below): per-thread copy of
// a SliceLengths-shaped slice from a source whose descriptor is known at compile time into a
// destination whose position (dst_coord_) is tracked at run time. Every element is passed through
// element_op_, DstScalarPerVector results are packed into one vector, and each vector is written
// with dst_buf.Update<DstInMemOp, ...>.
//
16// Assume:
17// 1. src:
18// 1. SrcDesc is known at compile-time
19// 2. SrcBuffer is StaticBuffer
20// 3. SrcSliceOriginIdx is known at compile-time
21// 2. dst:
22// 1. DstDesc is not known at compile-time
23// 2. DstBuffer is DynamicBuffer
24// 3. DstSliceOriginIdx is not known at compile time
//
// Template parameters:
//   SrcData, DstData           - element types; element_op_ produces one DstData per element.
//   SrcDesc                    - source descriptor; must be known at compile time (enable_if).
//   DstDesc                    - destination descriptor, used at run time via dst_coord_.
//   ElementwiseOperation       - callable invoked as element_op_(DstData& out, src_value).
//   SliceLengths               - compile-time slice lengths; nDim = SliceLengths::Size().
//   DimAccessOrder             - dimension order of the SpaceFillingCurve traversal.
//   DstVectorDim               - dimension along which elements are packed into one vector.
//   DstScalarPerVector         - elements per written vector; must divide
//                                SliceLengths[DstVectorDim] (checked in the constructor).
//   DstScalarStrideInVector    - not referenced by any visible line (possibly used on a missing one).
//   DstResetCoordinateAfterRun - if true, Run() moves dst_coord_ back to the slice origin.
//   DstInMemOp                 - used by Run(); its declaration is on a missing line
//                                (presumably original line 34 - confirm).
25template <typename SrcData,
26 typename DstData,
27 typename SrcDesc,
28 typename DstDesc,
29 typename ElementwiseOperation,
30 typename SliceLengths,
31 typename DimAccessOrder,
32 index_t DstVectorDim,
33 index_t DstScalarPerVector,
35 index_t DstScalarStrideInVector,
36 bool DstResetCoordinateAfterRun,
37 typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
39{
// Rank of the slice being copied.
40 static constexpr index_t nDim = SliceLengths::Size();
41
43
// Coordinate type for DstDesc and the type of a precomputed step that moves it.
// (Index is declared on a missing line, presumably original line 42 - confirm.)
44 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
45
46 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
47
// Places dst_coord_ at dst_slice_origin_idx in dst_desc and stores a copy of element_op.
48 __device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc& dst_desc,
49 const Index& dst_slice_origin_idx,
50 const ElementwiseOperation& element_op)
51 : dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin_idx)),
52 element_op_{element_op}
53 {
54 static_assert(SrcDesc::IsKnownAtCompileTime(),
55 "wrong! SrcDesc need to known at compile-time");
// Every vector access must cover whole vectors along DstVectorDim.
56 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
57 "wrong! Not divisible");
58 }
59
// Re-anchors the destination coordinate at dst_slice_origin_idx.
60 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
61 {
62 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
63 }
64
// Copies the whole slice. The source is read at compile-time offsets computed from SrcDesc and
// the compile-time SrcSliceOriginIdx. The destination is written at dst_coord_, which advances
// along a SpaceFillingCurve over SliceLengths in DimAccessOrder. Each packed vector is handed to
// dst_buf.Update together with is_dst_valid (computed on a missing line).
65 template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstBuffer>
66 __device__ void Run(const SrcDesc&,
67 const SrcSliceOriginIdx&,
68 const SrcBuffer& src_buf,
69 const DstDesc& dst_desc,
70 DstBuffer& dst_buf)
71 {
72 static_assert(SrcDesc::IsKnownAtCompileTime(),
73 "wrong! SrcDesc need to known at compile-time");
74
76 "wrong! SrcSliceOrigin need to known at compile-time");
77
78 static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer");
79
80 // SrcDesc and src_slice_origin_idx are known at compile-time
81 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
82 constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
83
84 // scalar per access on each dim
85 // TODO: don't use lambda_scalar_per_access
86 constexpr auto dst_scalar_per_access = generate_sequence(
88
89 constexpr auto dst_scalar_step_in_vector =
91
92 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
93 DimAccessOrder,
94 remove_cv_t<decltype(dst_scalar_per_access)>>;
95
96 // TODO: Use SpaceFillingCurve::ScalarsPerAccess instead of DstScalarPerVector?
97 static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
98 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
101
102 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
103
// One iteration per vector access along the curve (fully unrolled at compile time).
104 static_for<0, num_access, 1>{}([&](auto idx_1d) {
105 constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
106
107 // copy data from src_buf into dst_vector
108 // TODO: It's a hack here to use \p dst_scalar_step_in_vector. Use SpaceFillingCurve?
// (The per-element loop header and the dst_vector / dst_vector_t declarations are on
// missing lines 99-100 and 109.)
110 constexpr index_t src_offset = src_desc.CalculateOffset(
111 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
112
113 DstData v;
114
115 // apply element-wise operation
116 element_op_(v, src_buf[Number<src_offset>{}]);
117
118 dst_vector.template AsType<DstData>()(i) = v;
119 });
120
121 const bool is_dst_valid =
123
124 // copy data from dst_vector into dst_buf
125 dst_buf.template Update<DstInMemOp, dst_vector_t>(
126 dst_coord_.GetOffset(),
127 is_dst_valid,
128 dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);
129
// Advance to the next access; after the last access dst_coord_ stays where it is.
130 if constexpr(idx_1d.value != num_access - 1)
131 {
132 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
133
135 dst_desc, dst_coord_, make_tensor_coordinate_step(dst_desc, forward_step));
136 }
137 });
138
139 // move dst coordinate back to slice origin (or not)
140 if constexpr(DstResetCoordinateAfterRun)
141 {
142 const auto dst_reset_step =
144
145 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
146 }
147 }
148
// Index step used to bring dst_coord_ back to the slice origin after Run(). Returns a
// default-constructed Index when the curve has no accesses. Otherwise the value comes from the
// missing line 166 (presumably the curve's step from the last access back to the first - confirm).
149 __device__ static constexpr auto GetDstCoordinateResetStep()
150 {
151 constexpr auto dst_scalar_per_access = generate_sequence(
153
154 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
155 DimAccessOrder,
156 remove_cv_t<decltype(dst_scalar_per_access)>>;
157
158 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
159 if constexpr(num_access == 0)
160 {
161 return typename SpaceFillingCurve::Index{};
162 }
163 else
164 {
165 constexpr auto reset_step =
167
168 return reset_step;
169 }
170 }
171
// Moves the destination window by dst_slice_origin_step_idx. If Run() did not reset dst_coord_
// (DstResetCoordinateAfterRun == false), the reset step is folded into this single move.
172 // dst_slice_origin_step_idx should be known at compile-time, for performance reasons
173 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
174 const Index& dst_slice_origin_step_idx)
175 {
176 // if dst coord was not reset by Run(), then need to adjust the step here
177 const auto adjusted_step_idx =
178 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
179 : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
180
181 // is it OK to construct a new step every time?
182 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
183
184 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
185 }
186
187 private:
// Current destination position; moved by Run() and MoveDstSliceWindow().
188 DstCoord dst_coord_;
// NOTE(review): element_op_ is const-qualified, so this struct is not copy-assignable.
189 const ElementwiseOperation element_op_;
190}; // struct ThreadwiseTensorSliceTransfer_v1r3
191
// NOTE(review): this listing was extracted from a rendered source page. Each line still carries
// its original line number as a prefix. Original lines 192-220 before this template are missing,
// as are lines 233, 241, 248, 263-264, 285, 289, 299, 302, 319, 337, 350, 359, 368 and 382
// inside it. Restore them from the original header before compiling.
//
// ThreadwiseTensorSliceTransfer_v2 (name taken from the constructor below): per-thread copy of a
// SliceLengths-shaped slice from a source tracked at run time (src_coord_) into a destination
// whose descriptor is known at compile time. Each access loads one vector of
// SrcScalarPerVector / PackedSize SrcData elements along SrcVectorDim from src_buf. Each element
// is converted with type_convert<DstData> and stored at a compile-time offset in dst_buf.
//
// InvalidElementAsNaN: when true, an element loaded at an invalid source coordinate is replaced by
// the value on missing line 337 (presumably a NaN - confirm). The static_assert below restricts
// this to non-integral DstData.
221template <typename SrcData,
222 typename DstData,
223 typename SrcDesc,
224 typename DstDesc,
225 typename SliceLengths,
226 typename DimAccessOrder,
227 index_t SrcVectorDim,
228 index_t SrcScalarPerVector,
229 index_t SrcScalarStrideInVector,
230 bool SrcResetCoordinateAfterRun,
231 bool InvalidElementAsNaN = false,
232 typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
234{
235 static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
236 (!InvalidElementAsNaN),
237 "Filling invalid element as NaN is only for floating point types");
238
// Rank of the slice being copied.
239 static constexpr index_t nDim = SliceLengths::Size();
240
242
// Coordinate type for SrcDesc and the type of a precomputed step that moves it.
243 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
244
245 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
246
// Number of logical values stored per SrcData element: 2 when SrcData satisfies the condition on
// missing line 248 (presumably a packed sub-byte type - confirm), otherwise 1. Source offsets are
// divided by PackedSize and each loaded vector holds SrcScalarPerVector / PackedSize elements.
247 static constexpr index_t PackedSize = []() {
249 return 2;
250 else
251 return 1;
252 }();
253
// Places src_coord_ at src_slice_origin_idx in src_desc.
254 __device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc& src_desc,
255 const Index& src_slice_origin_idx)
256 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx))
257 {
// NOTE(review): this assertion checks DstDesc, but its message mentions SrcDesc.
258 static_assert(DstDesc::IsKnownAtCompileTime(),
259 "wrong! SrcDesc need to known at compile-time");
260 static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
261 "wrong! Not divisible");
262
265 {
266 static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
267 }
268 }
269
// Re-anchors the source coordinate at src_slice_origin_idx.
270 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
271 {
272 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
273 }
274
// Copies the whole slice. src_coord_ walks a SpaceFillingCurve over SliceLengths in
// DimAccessOrder. The destination offsets are computed at compile time from DstDesc and the
// compile-time DstSliceOriginIdx.
275 template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
276 __device__ void Run(const SrcDesc& src_desc,
277 const SrcBuffer& src_buf,
278 const DstDesc&,
279 const DstSliceOriginIdx&,
280 DstBuffer& dst_buf)
281 {
282 static_assert(DstDesc::IsKnownAtCompileTime(),
283 "wrong! DstDesc need to known at compile-time");
284
286 "wrong! DstSliceOrigin need to known at compile-time");
287
288 static_assert(
290 "wrong! inconsistent type");
291
292 // DstDesc and dst_slice_origin_idx are known at compile-time
293 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
294 constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
295
296 // scalar per access on each dim
297 // TODO: don't use lambda_scalar_per_access
298 constexpr auto src_scalar_per_access = generate_sequence(
300
301 constexpr auto src_scalar_step_in_vector =
303
304 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
305 DimAccessOrder,
306 remove_cv_t<decltype(src_scalar_per_access)>>;
307
308 // loop over tensor and copy
309 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
310
311 static_for<0, num_access, 1>{}([&](auto idx_1d) {
312 typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type src_vector;
313
314 using src_vector_t =
315 typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type::type;
316 constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
317
318 const bool is_src_valid =
320
321 // copy data from src_buf into src_vector
// (the offset is divided by PackedSize, i.e. expressed in SrcData elements)
322 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
323 src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize,
324 is_src_valid);
325
326 // copy data from src_vector into dst_buf
327 static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
328 constexpr index_t dst_offset =
329 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) + src_data_idx +
330 i * src_scalar_step_in_vector);
331
332 if constexpr(InvalidElementAsNaN)
333 {
334 dst_buf(Number<dst_offset>{}) =
335 is_src_valid
336 ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
338 }
339 else
340 {
341 dst_buf(Number<dst_offset>{}) =
342 type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
343 }
344 });
345
// Advance to the next access; after the last access src_coord_ stays where it is.
346 if constexpr(idx_1d.value != num_access - 1)
347 {
348 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
349
351 src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
352 }
353 });
354
355 // move src coordinate back to slice origin (or not)
356 if constexpr(SrcResetCoordinateAfterRun)
357 {
358 const auto src_reset_step =
360
361 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
362 }
363 }
364
// Index step used to bring src_coord_ back to the slice origin after Run(). Returns a
// default-constructed Index when the curve has no accesses. Otherwise the value comes from the
// missing line 382 (presumably the curve's step from the last access back to the first - confirm).
365 __device__ static constexpr auto GetSrcCoordinateResetStep()
366 {
367 constexpr auto src_scalar_per_access = generate_sequence(
369
370 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
371 DimAccessOrder,
372 remove_cv_t<decltype(src_scalar_per_access)>>;
373
374 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
375 if constexpr(num_access == 0)
376 {
377 return typename SpaceFillingCurve::Index{};
378 }
379 else
380 {
381 constexpr auto reset_step =
383
384 return reset_step;
385 }
386 }
387
// Moves the source window by src_slice_origin_step_idx. If Run() did not reset src_coord_
// (SrcResetCoordinateAfterRun == false), the reset step is folded into this single move.
388 // src_slice_origin_step_idx should be known at compile-time, for performance reasons
389 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
390 const Index& src_slice_origin_step_idx)
391 {
392 // if src coord was not reset by Run(), then need to adjust the step here
393 const auto adjusted_step_idx =
394 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
395 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
396
397 // is it OK to construct a new step every time?
398 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
399
400 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
401 }
402
// Same as the overload above, but forwards src_move_slice_window_step_hack to
// make_tensor_coordinate_step when building the step.
403 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
404 template <typename SrcMoveSliceWindowStepHack>
405 __device__ void
406 MoveSrcSliceWindow(const SrcDesc& src_desc,
407 const Index& src_slice_origin_step_idx,
408 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
409 {
410 // if src coord was not reset by Run(), then need to adjust the step here
411 const auto adjusted_step_idx =
412 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
413 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
414
415 // is it OK to construct a new step every time?
416 const auto adjusted_step = make_tensor_coordinate_step(
417 src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
418
419 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
420 }
421
422 private:
// Current source position; moved by Run() and MoveSrcSliceWindow().
423 SrcCoord src_coord_;
424}; // struct ThreadwiseTensorSliceTransfer_v2
425
// NOTE(review): this listing was extracted from a rendered source page. Each line still carries
// its original line number as a prefix, and several original lines are missing: 439 (struct
// name), 447, 454, 460 (constructor name), 472, 483, 502, 506, 516, 519, 542, 564, 577, 587, 596,
// 610 and 652 (presumably the scale_gather_offsets_ member - confirm). Restore them from the
// original header before compiling.
//
// Gather variant of the v2 transfer: it copies the same slice scale_gather_num times. Gather g
// reads the source at src_coord_ (in SrcData units) plus scale_gather_offsets_[g], and writes the
// destination shifted by (g, 0) from the destination origin.
426template <typename SrcData,
427 typename DstData,
428 typename SrcDesc,
429 typename DstDesc,
430 typename SliceLengths,
431 typename DimAccessOrder,
432 index_t SrcVectorDim,
433 index_t SrcScalarPerVector,
434 index_t SrcScalarStrideInVector,
435 bool SrcResetCoordinateAfterRun,
436 index_t scale_gather_num,
437 bool InvalidElementAsNaN = false,
438 typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
440{
441 static_assert((InvalidElementAsNaN && !ck::is_integral<DstData>::value) ||
442 (!InvalidElementAsNaN),
443 "Filling invalid element as NaN is only for floating point types");
444
// Rank of the slice being copied.
445 static constexpr index_t nDim = SliceLengths::Size();
446
448
449 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
450
451 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
452
// 2 when SrcData satisfies the condition on missing line 454 (presumably a packed sub-byte type -
// confirm), otherwise 1. Source offsets are divided by PackedSize.
453 static constexpr index_t PackedSize = []() {
455 return 2;
456 else
457 return 1;
458 }();
459
// Constructor: places src_coord_ at src_slice_origin_idx (as given, including dimension 0) and
// stores one source offset per gather.
461 const SrcDesc& src_desc,
462 const Index& src_slice_origin_idx,
463 const StaticallyIndexedArray<index_t, scale_gather_num>& scale_gather_offsets)
464 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin_idx)),
465 scale_gather_offsets_(scale_gather_offsets)
466 {
// NOTE(review): this assertion checks DstDesc, but its message mentions SrcDesc.
467 static_assert(DstDesc::IsKnownAtCompileTime(),
468 "wrong! SrcDesc need to known at compile-time");
469 static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
470 "wrong! Not divisible");
471
473 {
474 static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
475 }
476 }
477
// Re-anchors the source coordinate at src_slice_origin_idx with dimension 0 forced to 0.
// Presumably dimension 0 is addressed through scale_gather_offsets_ instead - confirm.
// NOTE(review): the constructor does not zero dimension 0, so the two entry points disagree when
// the dimension-0 origin is non-zero.
478 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
479 {
480 auto adjusted_origin_idx = [&]() {
481 Index idx;
482
484 [&](auto i) { idx(i) = i.value == 0 ? 0 : src_slice_origin_idx[Number<i>{}]; });
485
486 return idx;
487 }();
488
489 src_coord_ = make_tensor_coordinate(src_desc, adjusted_origin_idx);
490 }
491
// Copies the slice once per gather (see the struct comment). Within each gather, src_coord_ walks
// a SpaceFillingCurve over SliceLengths in DimAccessOrder, exactly as in the non-gather v2 Run().
492 template <typename SrcBuffer, typename DstBuffer, typename DstSliceOriginIdx>
493 __device__ void Run(const SrcDesc& src_desc,
494 const SrcBuffer& src_buf,
495 const DstDesc&,
496 const DstSliceOriginIdx&,
497 DstBuffer& dst_buf)
498 {
499 static_assert(DstDesc::IsKnownAtCompileTime(),
500 "wrong! DstDesc need to known at compile-time");
501
503 "wrong! DstSliceOrigin need to known at compile-time");
504
505 static_assert(
507 "wrong! inconsistent type");
508
509 // DstDesc and dst_slice_origin_idx are known at compile-time
510 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
511 constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{};
512
513 // scalar per access on each dim
514 // TODO: don't use lambda_scalar_per_access
515 constexpr auto src_scalar_per_access = generate_sequence(
517
518 constexpr auto src_scalar_step_in_vector =
520
521 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
522 DimAccessOrder,
523 remove_cv_t<decltype(src_scalar_per_access)>>;
524
525 // loop over tensor and copy
526 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
527
// NOTE(review): src_coord_ is advanced through num_access - 1 steps inside each gather, but it is
// not moved back before the next gather_idx. The only reset happens after the outer loop. When
// num_access > 1, later gathers therefore start from the previous gather's last access, and the
// final reset step undoes only one pass.
528 static_for<0, scale_gather_num, 1>{}([&](auto gather_idx) {
// NOTE(review): make_multi_index(gather_idx, 0) has two entries, so this only type-checks when
// the destination origin is two-dimensional (presumably nDim == 2 - confirm).
529 constexpr auto current_dst_origin =
530 to_multi_index(dst_slice_origin_idx) + make_multi_index(gather_idx, 0);
531
532 static_for<0, num_access, 1>{}([&](auto idx_1d) {
533 typename vector_type_maker<SrcData, SrcScalarPerVector / PackedSize>::type
534 src_vector;
535
536 using src_vector_t =
537 typename vector_type_maker<SrcData,
538 SrcScalarPerVector / PackedSize>::type::type;
539 constexpr auto src_data_idx = SpaceFillingCurve::GetIndex(idx_1d);
540
// NOTE(review): validity appears to be derived from src_coord_ alone; the per-gather offset
// added below is not part of the check.
541 const bool is_src_valid =
543 src_coord_);
544
545 // copy data from src_buf into src_vector
// (the gather offset is added after dividing by PackedSize, i.e. in SrcData elements)
546 src_vector.template AsType<src_vector_t>()(Number<0>{}) =
547 src_buf.template Get<src_vector_t>(src_coord_.GetOffset() / PackedSize +
548 scale_gather_offsets_(gather_idx),
549 is_src_valid);
550
551 // copy data from src_vector into dst_buf
// NOTE(review): dst_slice_origin_idx enters both dst_offset and current_dst_origin, so a non-zero
// destination origin is counted twice (assuming dst_desc's offset is linear in the index -
// confirm).
552 static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
553 constexpr index_t dst_offset =
554 dst_desc.CalculateOffset(to_multi_index(dst_slice_origin_idx) +
555 src_data_idx + i * src_scalar_step_in_vector);
556 constexpr auto full_dst_offset =
557 dst_desc.CalculateOffset(current_dst_origin) + dst_offset;
558
// NOTE(review): this branch indexes dst_buf with a plain full_dst_offset, while the else branch
// uses Number<full_dst_offset>{}. If dst_buf only accepts compile-time Number<> indices, the
// branch may fail to compile when InvalidElementAsNaN is true.
559 if constexpr(InvalidElementAsNaN)
560 {
561 dst_buf(full_dst_offset) =
562 is_src_valid
563 ? type_convert<DstData>(src_vector.template AsType<SrcData>()[i])
565 }
566 else
567 {
568 dst_buf(Number<full_dst_offset>{}) =
569 type_convert<DstData>(src_vector.template AsType<SrcData>()[i]);
570 }
571 });
572
573 if constexpr(idx_1d.value != num_access - 1)
574 {
575 constexpr auto forward_step = SpaceFillingCurve::GetForwardStep(idx_1d);
576
578 src_desc, src_coord_, make_tensor_coordinate_step(src_desc, forward_step));
579 }
580 });
581 });
582
583 // move src coordinate back to slice origin (or not)
584 if constexpr(SrcResetCoordinateAfterRun)
585 {
586 const auto src_reset_step =
588
589 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
590 }
591 }
592
// Index step used to bring src_coord_ back to the slice origin after one curve traversal. Returns
// a default-constructed Index when the curve has no accesses. Otherwise the value comes from the
// missing line 610.
593 __device__ static constexpr auto GetSrcCoordinateResetStep()
594 {
595 constexpr auto src_scalar_per_access = generate_sequence(
597
598 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
599 DimAccessOrder,
600 remove_cv_t<decltype(src_scalar_per_access)>>;
601
602 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
603 if constexpr(num_access == 0)
604 {
605 return typename SpaceFillingCurve::Index{};
606 }
607 else
608 {
609 constexpr auto reset_step =
611
612 return reset_step;
613 }
614 }
615
// Moves the source window by src_slice_origin_step_idx, folding in the reset step when Run() did
// not reset src_coord_.
616 // src_slice_origin_step_idx should be known at compile-time, for performance reasons
617 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
618 const Index& src_slice_origin_step_idx)
619 {
620 // if src coord was not reset by Run(), then need to adjust the step here
621 const auto adjusted_step_idx =
622 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
623 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
624
625 // is it OK to construct a new step every time?
626 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
627
628 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
629 }
630
// Same as the overload above, but forwards src_move_slice_window_step_hack to
// make_tensor_coordinate_step.
631 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
632 template <typename SrcMoveSliceWindowStepHack>
633 __device__ void
634 MoveSrcSliceWindow(const SrcDesc& src_desc,
635 const Index& src_slice_origin_step_idx,
636 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
637 {
638 // if src coord was not reset by Run(), then need to adjust the step here
639 const auto adjusted_step_idx =
640 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
641 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
642
643 // is it OK to construct a new step every time?
644 const auto adjusted_step = make_tensor_coordinate_step(
645 src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
646
647 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
648 }
649
650 private:
// Current source position. (The scale_gather_offsets_ member used above is declared on the
// missing line 652.)
651 SrcCoord src_coord_;
653}; // end of the gather-variant struct (its name line is missing from this extract)
654
// NOTE(review): this listing was extracted from a rendered source page. Each line still carries
// its original line number as a prefix; original lines 680 (struct name) and 683 (presumably the
// Index alias) are missing here.
//
// ThreadwiseTensorSliceTransfer_v3 (name taken from the constructor): two-phase per-thread copy.
// RunRead() loads the slice from src_buf into a thread-local buffer (buffer_), and RunWrite()
// stores it from buffer_ into dst_buf. Each phase has its own access order, vector dimension and
// vector width.
// NOTE(review): DstInMemOp is not referenced by any visible line; RunWrite() writes with
// dst_buf.Set(...). Confirm whether DstInMemOp is meant to be honored.
655// Assume:
656// 1. src_desc and dst_desc are not known at compile-time
657// 2. SrcBuffer and DstBuffer are DynamicBuffer
658// 3. src_slice_origin and dst_slice_origin are not known at compile-time
659// 4. Use thread buffer
660template <typename SliceLengths,
661 InMemoryDataOperationEnum DstInMemOp,
662 typename SrcData,
663 typename DstData,
664 typename SrcDesc,
665 typename DstDesc,
666 typename SrcDimAccessOrder,
667 typename DstDimAccessOrder,
668 index_t SrcVectorDim,
669 index_t DstVectorDim,
670 index_t SrcScalarPerVector,
671 index_t DstScalarPerVector,
672 index_t SrcScalarStrideInVector,
673 index_t DstScalarStrideInVector,
674 bool SrcResetCoordinateAfterRun, // control whether to move back src coordinate after each
675 // RunRead(), will be fused with MoveSrcSliceWindow to
676 // save addr computation
677 bool DstResetCoordinateAfterRun> // control whether to move back dst coordinate after each
678 // RunWrite(), will be fused with MoveDstSliceWindow to
679 // save addr computation
681{
// Rank of the slice being copied.
682 static constexpr index_t nDim = SliceLengths::Size();
684
// Coordinate types for both descriptors, and the types of precomputed steps that move them.
685 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
686 using DstCoord = decltype(make_tensor_coordinate(DstDesc{}, Index{}));
687
688 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
689 using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{}));
690
691 __device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc& src_desc,
692 const Index& src_slice_origin,
693 const DstDesc& dst_desc,
694 const Index& dst_slice_origin)
695 : src_coord_(make_tensor_coordinate(src_desc, src_slice_origin)),
696 dst_coord_(make_tensor_coordinate(dst_desc, dst_slice_origin))
697 {
698 static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
699 "wrong! Not divisible");
700 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
701 "wrong! Not divisible");
702 }
703
704 __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
705 {
706 src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx);
707 }
708
709 __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
710 {
711 dst_coord_ = make_tensor_coordinate(dst_desc, dst_slice_origin_idx);
712 }
713
// NOTE(review): extracted listing - original lines 723, 753, 767, 776, 807, 812, 819, 827, 847,
// 852 and 863 are missing, and each line carries its original line number as a prefix.
//
// Reads the whole slice from src_buf into the thread buffer buffer_ (buffer_ and buffer_desc_ are
// declared outside this extract). The slice is visited in vectors of SrcScalarPerVector elements
// along SrcVectorDim, with dimensions ordered by SrcDimAccessOrder. Each dimension alternates
// direction (zig-zag), so consecutive accesses differ by one forward or backward step on a single
// dimension. Those steps are precomputed once per dimension. src_step_hacks[0][d] and
// src_step_hacks[1][d] accompany the forward and backward step of dimension d; they are passed on
// lines whose callee is missing (presumably make_tensor_coordinate_step - confirm).
714 template <typename SrcBuffer, typename SrcStepHacks>
715 __device__ void
716 RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf, const SrcStepHacks& src_step_hacks)
717 {
// The source must live in global memory or LDS.
718 static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
719 SrcBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
720 "wrong!");
721
722 static_assert(
724 "wrong! SrcBuffer and SrcData data type are inconsistent");
725
726 constexpr auto I0 = Number<0>{};
727 constexpr auto I1 = Number<1>{};
728
729 // scalar per access on each dim
730 // TODO: don't use lambda_scalar_per_access
731 constexpr auto src_scalar_per_access = generate_sequence(
733
734 constexpr auto src_scalar_step_in_vector =
736
// Number of vector accesses per dimension, then permuted into access order.
737 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
738
739 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
740
741 constexpr auto ordered_src_access_lengths =
742 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
743
744 // make forward steps
// (step d is one vector width forward on dimension d and zero elsewhere)
745 const auto src_forward_steps = generate_tuple(
746 [&](auto i) {
747 Index forward_step_idx;
748
749 static_for<0, nDim, 1>{}([&](auto j) {
750 forward_step_idx(j) = (i.value == j.value) ? src_scalar_per_access[i] : 0;
751 });
752
754 src_desc, forward_step_idx, src_step_hacks[I0][i]);
755 },
756 Number<nDim>{});
757
758 // make backward steps
759 const auto src_backward_steps = generate_tuple(
760 [&](auto i) {
761 Index backward_step_idx;
762
763 static_for<0, nDim, 1>{}([&](auto j) {
764 backward_step_idx(j) = (i.value == j.value) ? -src_scalar_per_access[i] : 0;
765 });
766
768 src_desc, backward_step_idx, src_step_hacks[I1][i]);
769 },
770 Number<nDim>{});
771
772 // loop over tensor and copy
773 static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
774 // judge move forward or move backward
// Dimension i (in access order) runs forward when the linear index formed by dimensions
// 0..i-1 is even, so its direction flips each time an outer dimension advances.
775 constexpr auto forward_sweep = [&]() {
777
778 forward_sweep_(I0) = true;
779
780 static_for<1, nDim, 1>{}([&](auto i) {
781 index_t tmp = ordered_src_access_idx[I0];
782
783 static_for<1, i, 1>{}([&](auto j) {
784 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
785 });
786
787 forward_sweep_(i) = tmp % 2 == 0;
788 });
789
790 return forward_sweep_;
791 }();
792
793 // calculate src data index
// Mirror backward-running dimensions, restore natural dimension order, scale to elements.
794 constexpr auto src_data_idx = [&]() {
795 Index ordered_idx;
796
797 static_for<0, nDim, 1>{}([&](auto i) {
798 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
799 : ordered_src_access_lengths[i] - 1 -
800 ordered_src_access_idx[i];
801 });
802
803 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
804 src_scalar_per_access;
805 }();
806
808
809 using src_vector_t = typename decltype(src_tmp_vector)::type;
810
811 const bool is_src_valid =
813
814 // copy data from src_buf to src_tmp_vector
815 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
816 src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
817
818 // copy data from src_tmp_vector to buffer_
820 constexpr index_t buffer_offset =
821 buffer_desc_.CalculateOffset(src_data_idx + i * src_scalar_step_in_vector);
822
823 buffer_(Number<buffer_offset>{}) = src_tmp_vector.template AsType<SrcData>()[i];
824 });
825
// Dimension i moves after this access iff it is not at its last index and every later
// dimension (in access order) is at its last index, so at most one dimension moves.
826 constexpr auto move_on_dim = [&]() constexpr {
828
829 static_for<0, nDim, 1>{}([&](auto i) {
830 move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
831
832 static_for<i + 1, nDim, 1>{}([&](auto j) {
833 move_on_dim_(i) &=
834 ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
835 });
836 });
837
838 return move_on_dim_;
839 }();
840
841 // move
// (step along natural dimension src_dim_access_order[i] in the current sweep direction)
842 static_for<0, nDim, 1>{}([&](auto i) {
843 if constexpr(move_on_dim[i])
844 {
845 if constexpr(forward_sweep[i])
846 {
848 src_desc, src_coord_, src_forward_steps[src_dim_access_order[i]]);
849 }
850 else
851 {
853 src_desc, src_coord_, src_backward_steps[src_dim_access_order[i]]);
854 }
855 }
856 });
857 });
858
859 // move src coordinate back to slice origin (or not)
860 if constexpr(SrcResetCoordinateAfterRun)
861 {
862 const auto src_reset_step =
864
865 move_tensor_coordinate(src_desc, src_coord_, src_reset_step);
866 }
867 }
868
// NOTE(review): extracted listing - original lines 878, 887, 890, 908, 922, 931, 962, 965, 970,
// 977, 985, 1005, 1010 and 1021 are missing, and each line carries its original line number as a
// prefix.
//
// Writes the whole slice from the thread buffer buffer_ into dst_buf. The traversal mirrors
// RunRead(): vectors of DstScalarPerVector elements along DstVectorDim, dimensions ordered by
// DstDimAccessOrder, and a zig-zag direction per dimension, so consecutive accesses differ by a
// single precomputed forward or backward step. dst_step_hacks[0][d] and dst_step_hacks[1][d]
// accompany the forward and backward step of dimension d.
// NOTE(review): each vector is stored with dst_buf.Set(...); the DstInMemOp template parameter is
// not used here - confirm that this is intended.
869 template <typename DstBuffer, typename DstStepHacks>
870 __device__ void
871 RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf, const DstStepHacks& dst_step_hacks)
872 {
// The destination must live in global memory or LDS.
873 static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Global or
874 DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
875 "wrong!");
876
877 static_assert(
879 "wrong! SrcBuffer or DstBuffer data type is wrong");
880
881 constexpr auto I0 = Number<0>{};
882 constexpr auto I1 = Number<1>{};
883
884 // dst scalar per access on each dim
885 // TODO: don't use this
886 constexpr auto dst_scalar_per_access = generate_sequence(
888
889 constexpr auto dst_scalar_step_in_vector =
891
// Number of vector accesses per dimension, then permuted into access order.
892 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
893
894 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
895
896 constexpr auto ordered_dst_access_lengths =
897 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
898
899 // make forward steps
900 const auto dst_forward_steps = generate_tuple(
901 [&](auto i) {
902 Index forward_step_idx;
903
904 static_for<0, nDim, 1>{}([&](auto j) {
905 forward_step_idx(j) = (i.value == j.value) ? dst_scalar_per_access[i] : 0;
906 });
907
909 dst_desc, forward_step_idx, dst_step_hacks[I0][i]);
910 },
911 Number<nDim>{});
912
913 // make backward steps
914 const auto dst_backward_steps = generate_tuple(
915 [&](auto i) {
916 Index backward_step_idx;
917
918 static_for<0, nDim, 1>{}([&](auto j) {
919 backward_step_idx(j) = (i.value == j.value) ? -dst_scalar_per_access[i] : 0;
920 });
921
923 dst_desc, backward_step_idx, dst_step_hacks[I1][i]);
924 },
925 Number<nDim>{});
926
927 // loop over tensor and copy
928 static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
929 // judge move forward or move backward
// Dimension i (in access order) runs forward when the linear index formed by dimensions
// 0..i-1 is even.
930 constexpr auto forward_sweep = [&]() {
932
933 forward_sweep_(I0) = true;
934
935 static_for<1, nDim, 1>{}([&](auto i) {
936 index_t tmp = ordered_dst_access_idx[I0];
937
938 static_for<1, i, 1>{}([&](auto j) {
939 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
940 });
941
942 forward_sweep_(i) = tmp % 2 == 0;
943 });
944
945 return forward_sweep_;
946 }();
947
948 // calculate dst data index
949 constexpr auto dst_data_idx = [&]() {
950 Index ordered_idx;
951
952 static_for<0, nDim, 1>{}([&](auto i) {
953 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
954 : ordered_dst_access_lengths[i] - 1 -
955 ordered_dst_access_idx[i];
956 });
957
958 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
959 dst_scalar_per_access;
960 }();
961
963
964 // copy data from buffer_ to dst_tmp_vector
966 constexpr index_t buffer_offset =
967 buffer_desc_.CalculateOffset(dst_data_idx + i * dst_scalar_step_in_vector);
968
969 dst_tmp_vector.template AsType<DstData>()(i) =
971 });
972
973 using dst_vector_t = typename decltype(dst_tmp_vector)::type;
974
975 // copy data from dst_tmp_vector to dst_buf
976 const bool is_dst_valid =
978
979 dst_buf.template Set<dst_vector_t>(
980 dst_coord_.GetOffset(),
981 is_dst_valid,
982 dst_tmp_vector.template AsType<dst_vector_t>()[Number<0>{}]);
983
// Dimension i moves after this access iff it is not at its last index and every later
// dimension (in access order) is at its last index.
984 constexpr auto move_on_dim = [&]() constexpr {
986
987 static_for<0, nDim, 1>{}([&](auto i) {
988 move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
989
990 static_for<i + 1, nDim, 1>{}([&](auto j) {
991 move_on_dim_(i) &=
992 ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
993 });
994 });
995
996 return move_on_dim_;
997 }();
998
999 // move
1000 static_for<0, nDim, 1>{}([&](auto i) {
1001 if constexpr(move_on_dim[i])
1002 {
1003 if constexpr(forward_sweep[i])
1004 {
1006 dst_desc, dst_coord_, dst_forward_steps[dst_dim_access_order[i]]);
1007 }
1008 else
1009 {
1011 dst_desc, dst_coord_, dst_backward_steps[dst_dim_access_order[i]]);
1012 }
1013 }
1014 });
1015 });
1016
1017 // move dst coordinate back to slice origin (or not)
1018 if constexpr(DstResetCoordinateAfterRun)
1019 {
1020 const auto dst_reset_step =
1022
1023 move_tensor_coordinate(dst_desc, dst_coord_, dst_reset_step);
1024 }
1025 }
1026
1027 template <typename SrcBuffer>
1028 __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
1029 {
1030 constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();
1031
1032 constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};
1033
1034 constexpr auto src_step_hacks =
1035 make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
1036 generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
1037
1038 RunRead(src_desc, src_buf, src_step_hacks);
1039 }
1040
1041 template <typename DstBuffer>
1042 __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
1043 {
1044 constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();
1045
1046 constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};
1047
1048 constexpr auto dst_step_hacks =
1049 make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
1050 generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));
1051
1052 RunWrite(dst_desc, dst_buf, dst_step_hacks);
1053 }
1054
// NOTE(review): extracted listing - original lines 1062 and 1073 are missing, and each line
// carries its original line number as a prefix.
//
// Returns the index step that takes src_coord_ from the last access made by RunRead() back to the
// slice origin. It replays RunRead()'s zig-zag direction rule for the final iteration (every
// ordered access index at its maximum), computes the element index of that last access, and
// negates it.
1055 __device__ static constexpr auto GetSrcCoordinateResetStep()
1056 {
1057 constexpr auto I0 = Number<0>{};
1058
1059 // scalar per access on each dim
1060 // TODO: don't use lambda_scalar_per_access
1061 constexpr auto src_scalar_per_access = generate_sequence(
1063
1064 constexpr auto src_access_lengths = SliceLengths{} / src_scalar_per_access;
1065
1066 constexpr auto src_dim_access_order = SrcDimAccessOrder{};
1067
1068 constexpr auto ordered_src_access_lengths =
1069 container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
1070
1071 // judge move forward or move backward during the last iteration
1072 constexpr auto forward_sweep = [&]() {
1074
1075 forward_sweep_(I0) = true;
1076
1077 static_for<1, nDim, 1>{}([&](auto i) {
1078 index_t tmp = ordered_src_access_lengths[I0] - 1;
1079
1080 static_for<1, i, 1>{}([&](auto j) {
1081 tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
1082 });
1083
1084 forward_sweep_(i) = tmp % 2 == 0;
1085 });
1086
1087 return forward_sweep_;
1088 }();
1089
1090 // calculate src data index after last iteration in RunRead(), if it has not been reset by
1091 // RunRead()
1092 constexpr auto src_data_idx = [&]() {
1093 Index ordered_idx;
1094
1095 static_for<0, nDim, 1>{}([&](auto i) {
1096 ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
1097 });
1098
1099 return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
1100 src_scalar_per_access;
1101 }();
1102
1103 // negate the last access index: the step from there back to the slice origin
1104 constexpr auto reset_src_data_step = [&]() {
1105 Index reset_src_data_step_;
1106
1107 static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
1108
1109 return reset_src_data_step_;
1110 }();
1111
1112 return reset_src_data_step;
1113 }
1114
// Returns the index step that moves dst_coord_ from the last element accessed by RunWrite()
// back to the slice origin. MoveDstSliceWindow() adds this step when
// DstResetCoordinateAfterRun is false. Everything here is evaluated at compile time.
1115 __device__ static constexpr auto GetDstCoordinateResetStep()
1116 {
1117 constexpr auto I0 = Number<0>{};
1118
// NOTE(review): the generator passed to generate_sequence below is not visible here; it
// presumably yields DstScalarPerVector on DstVectorDim and 1 on every other dim -- confirm.
1119 // scalar per access on each dim
1120 // TODO: don't use lambda_scalar_per_access
1121 constexpr auto dst_scalar_per_access = generate_sequence(
1123
// number of accesses along each dim (slice length / scalars per access)
1124 constexpr auto dst_access_lengths = SliceLengths{} / dst_scalar_per_access;
1125
1126 constexpr auto dst_dim_access_order = DstDimAccessOrder{};
1127
// access counts permuted into traversal order (DstDimAccessOrder)
1128 constexpr auto ordered_dst_access_lengths =
1129 container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);
1130
1131 // judge whether each ordered dim moves forward or backward during the last iteration:
// ordered dim 0 is always forward; ordered dim i > 0 is forward iff tmp, the linear index of
// the final iteration over ordered dims [0, i), is even -- i.e. the direction of dim i flips
// with every step of the outer dims (presumably matching RunWrite()'s traversal -- confirm).
// NOTE(review): the declaration of forward_sweep_ is not visible here (presumably a
// StaticallyIndexedArray<bool, nDim>) -- confirm.
1132 constexpr auto forward_sweep = [&]() {
1134
1135 forward_sweep_(I0) = true;
1136
1137 static_for<1, nDim, 1>{}([&](auto i) {
1138 index_t tmp = ordered_dst_access_lengths[I0] - 1;
1139
1140 static_for<1, i, 1>{}([&](auto j) {
1141 tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
1142 });
1143
1144 forward_sweep_(i) = tmp % 2 == 0;
1145 });
1146
1147 return forward_sweep_;
1148 }();
1149
1150 // calculate dst data index after last iteration in RunWrite(), if it has not been reset by
1151 // RunWrite(): (access length - 1) on forward-swept dims and 0 on backward-swept dims,
// mapped back to the natural dim order and scaled by dst_scalar_per_access
1152 constexpr auto dst_data_idx = [&]() {
1153 Index ordered_idx;
1154
1155 static_for<0, nDim, 1>{}([&](auto i) {
1156 ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
1157 });
1158
1159 return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
1160 dst_scalar_per_access;
1161 }();
1162
1163 // the reset step is the element-wise negation of dst_data_idx
1164 constexpr auto reset_dst_data_step = [&]() {
1165 Index reset_dst_data_step_;
1166
1167 static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
1168
1169 return reset_dst_data_step_;
1170 }();
1171
1172 return reset_dst_data_step;
1173 }
1174
1175 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
1176 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
1177 const Index& src_slice_origin_step_idx)
1178 {
1179 // if src coord was not reset by RunRead(), then need to adjust the step here
1180 const auto adjusted_step_idx =
1181 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
1182 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
1183
1184 // is it OK to construct a new step every time?
1185 const auto adjusted_step = make_tensor_coordinate_step(src_desc, adjusted_step_idx);
1186
1187 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
1188 }
1189
1190 // src_slice_origin_step_idx need to be known at compile-time, for performance reason
1191 template <typename SrcMoveSliceWindowStepHack>
1192 __device__ void
1193 MoveSrcSliceWindow(const SrcDesc& src_desc,
1194 const Index& src_slice_origin_step_idx,
1195 const SrcMoveSliceWindowStepHack& src_move_slice_window_step_hack)
1196 {
1197 // if src coord was not reset by RunRead(), then need to adjust the step here
1198 const auto adjusted_step_idx =
1199 SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
1200 : src_slice_origin_step_idx + GetSrcCoordinateResetStep();
1201
1202 // is it OK to construct a new step every time?
1203 const auto adjusted_step = make_tensor_coordinate_step(
1204 src_desc, adjusted_step_idx, src_move_slice_window_step_hack);
1205
1206 move_tensor_coordinate(src_desc, src_coord_, adjusted_step);
1207 }
1208 // dst_slice_origin_step_idx need to be known at compile-time, for performance reason
1209 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
1210 const Index& dst_slice_origin_step_idx)
1211 {
1212 // if dst coord was not reset by RunWrite(), then need to adjust the step here
1213 const auto adjusted_step_idx =
1214 DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
1215 : dst_slice_origin_step_idx + GetDstCoordinateResetStep();
1216
1217 // is it OK to construct a new step every time?
1218 const auto adjusted_step = make_tensor_coordinate_step(dst_desc, adjusted_step_idx);
1219
1220 move_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
1221 }
1222
1223 private:
// Descriptor of the staging buffer. Its initializer is not visible here (presumably a packed
// descriptor built from SliceLengths -- confirm).
1224 static constexpr auto buffer_desc_ =
1226
// element-space size of buffer_desc_, used as the staging buffer's capacity
1227 static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
1228
// Staging buffer of buffer_size_ SrcData elements in the register address space
// (AddressSpaceEnum::Vgpr). NOTE(review): presumably filled by RunRead() and drained by
// RunWrite(); the meaning of the trailing `true` argument is not visible here -- confirm.
1229 StaticBuffer<AddressSpaceEnum::Vgpr, SrcData, buffer_size_, true> buffer_;
1230
// Run-time source/destination coordinates, advanced by MoveSrcSliceWindow() /
// MoveDstSliceWindow().
1231 SrcCoord src_coord_;
1232 DstCoord dst_coord_;
1233};
1234
// Threadwise copy of a compile-time-shaped slice from src_buf, addressed relative to a run-time
// reference coordinate (src_ref_coord_), into the StaticBuffer dst_buf, converting SrcData to
// DstData. Source reads are vectorized along SrcVectorDim (each access fills a
// SrcScalarPerVector / PackedSize element vector of SrcData).
// NOTE(review): SrcScalarStrideInVector is not referenced in any visible line of this struct;
// Run() hard-codes the in-vector step to 1 along SrcVectorDim -- confirm it is unused.
1235// Assume:
1236// 1. src:
1237// 1. SrcDesc is known at compile-time
1238// 2. SrcBuffer is DynamicBuffer (Run() also accepts a StaticBuffer source)
1239// 3. src_ref_idx is known at run-time
1240// 4. SrcRefToOriginDisplacement is known at compile-time
1241// 5. use #-step
1242// 2. dst:
1243// 1. DstDesc is known at compile-time
1244// 2. DstBuffer is StaticBuffer
1245// 3. DstOriginIdx is known at compile-time
1246// 4. use direct address calculation
1247// 3. vector access on src
1248template <typename SrcData,
1249 typename DstData,
1250 typename SrcDesc,
1251 typename DstDesc,
1252 typename SliceLengths,
1253 typename DimAccessOrder,
1254 index_t SrcVectorDim,
1255 index_t SrcScalarPerVector,
1256 index_t SrcScalarStrideInVector,
1257 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1258 bool>::type = false>
1260{
1261 static constexpr index_t nDim = SliceLengths::Size();
1262
// NOTE(review): the Index alias declaration is not visible here (presumably MultiIndex<nDim>).
1264
// coordinate type and coordinate-step type of SrcDesc
1265 using SrcCoord = decltype(make_tensor_coordinate(SrcDesc{}, Index{}));
1266
1267 using SrcCoordStep = decltype(make_tensor_coordinate_step(SrcDesc{}, Index{}));
1268
// Number of values stored per SrcData element: 2 when the (not visible) `if constexpr`
// condition holds -- presumably a packed SrcData type, cf. "pk data" in the constructor --
// otherwise 1. Source offsets and vector widths are divided by it.
1269 static constexpr index_t PackedSize = []() {
1271 return 2;
1272 else
1273 return 1;
1274 }();
1275
// Builds the run-time reference coordinate from src_ref_idx. The DynamicBuffer path of Run()
// starts every source access from this coordinate.
1276 __device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index& src_ref_idx)
1277 : src_ref_coord_(make_tensor_coordinate(SrcDesc{}, src_ref_idx))
1278 {
1279 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1280 "wrong! SrcDesc and DstDesc need to known at compile-time");
1281
// NOTE(review): the `if constexpr` condition guarding this scope is not visible here;
// presumably it selects packed SrcData, whose vector width must be a multiple of PackedSize.
1284 {
1285 static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
1286 }
1287 }
1288
// Copies the slice into dst_buf. Access positions are enumerated with static_ford in
// DimAccessOrder; for each one, a vector is read from src_buf at
// src_ref_coord_ + SrcRefToOriginDisplacement + position, converted to DstData, and its scalars
// are stored into dst_buf at DstOriginIdx + position (+ i along SrcVectorDim).
// Descriptors, displacement and destination origin are compile-time; only src_ref_coord_ is
// run-time. The descriptor and origin arguments are used only for their types.
1289 template <typename SrcRefToOriginDisplacement,
1290 typename DstOriginIdx,
1291 typename SrcBuffer,
1292 typename DstBuffer>
1293 __device__ void Run(const SrcDesc&,
1294 const SrcRefToOriginDisplacement&,
1295 const SrcBuffer& src_buf,
1296 const DstDesc&,
1297 const DstOriginIdx&,
1298 DstBuffer& dst_buf) const
1299 {
1300 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1301 "wrong! SrcDesc and DstDesc need to known at compile-time");
1302
1303 static_assert(
1306 "wrong! SrcBuffer or DstBuffer data type is wrong");
1307
1308 static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
1309
1312 "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
1313 "at compile-time");
1314
1315 // SrcDesc and DstDesc are known at compile-time
1316 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
1317 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
1318
1319 // SrcRefToOriginDisplacement and DstOriginIdx are known at compile-time
1320 constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
1321 constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
1322
1323 // scalar per access of each dim (1 except on SrcVectorDim, whose value is not visible here)
1324 constexpr auto src_scalar_per_access = generate_sequence_v2(
1325 [&](auto i) constexpr {
1326 if constexpr(i == SrcVectorDim)
1327 {
1329 }
1330 else
1331 {
1332 return Number<1>{};
1333 }
1334 },
1335 Number<nDim>{});
1336
1337 // scalar step (if stepping on SrcVectorDim) of each dim
1338 constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
1339 [&](auto i) constexpr {
1340 if constexpr(i == SrcVectorDim)
1341 {
1342 return Number<1>{};
1343 }
1344 else
1345 {
1346 return Number<0>{};
1347 }
1348 },
1349 Number<nDim>{});
1350
// number of vector accesses along each dim, then permuted into DimAccessOrder
1351 constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
1352
1353 constexpr auto dim_access_order = DimAccessOrder{};
1354
1355 constexpr auto ordered_access_lengths =
1356 container_reorder_given_new2old(access_lengths, dim_access_order);
1357
1358 static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
1359#if 0
1360 // TODO: unable to compile
1361 // position in slice window
1362 constexpr auto data_to_origin_disp_idx =
1363 container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
1364 src_scalar_per_access;
1365#else
1366 // position in slice window
1367 constexpr auto data_to_origin_disp_idx =
1368 ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
1369#endif
1370 // src coordinate: copy of src_ref_coord_ moved by a compile-time displacement
1371 constexpr auto src_ref_to_data_disp_idx =
1372 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
1373
1374 constexpr auto src_ref_to_data_disp_coord_step =
1375 make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
1376
1377 auto src_data_coord = src_ref_coord_;
1378
1379 move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
1380
1381 vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
1382
1383 using src_vector_t = typename decltype(src_tmp_vector)::type;
1384
// is_src_valid: validity flag for src_data_coord; its declaration line is not visible here
// (presumably coordinate_has_valid_offset_assuming_visible_index_is_valid -- confirm).
1386 src_desc, src_data_coord);
1387
1388 // copy data from src_buf into src_tmp_vector
// DynamicBuffer: one vector Get() at the coordinate's offset divided by PackedSize, with
// is_src_valid forwarded to Get().
1389 if constexpr(SrcBuffer::IsDynamicBuffer())
1390 {
1391 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
1392 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
1393 is_src_valid);
1394 }
// StaticBuffer: scalar reads at compile-time offsets (the loop header is not visible here).
// NOTE(review): these offsets use only the compile-time displacement and do not include
// src_ref_coord_ -- confirm this is intended.
1395 else if constexpr(SrcBuffer::IsStaticBuffer())
1396 {
1398 constexpr index_t src_offset = src_desc.CalculateOffset(
1399 src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
1400 i * src_scalar_step_in_vector);
1401
1402 src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
1403 });
1404 }
1405
// Packed conversion path: its `if constexpr` condition, the dst_tmp_vector declaration, the
// 8-wide conversion call and the store-loop header are not visible here.
1407 {
1408 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1409 // DstData), pack_size values per conversion call
1411
1412 constexpr index_t pack_size = 8;
1413
1414 static_assert(SrcScalarPerVector % pack_size == 0, "");
1415
1416 using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
1417 using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
1418
1419 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
1421 dst_tmp_vector.template AsType<dst_v_t>()(i),
1422 src_tmp_vector.template AsType<src_v_t>()[i]);
1423 });
1424
1425 // copy data from dst_tmp_vector into dst_buf
1427 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1428 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1429
1430 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1431 });
1432 }
// f8 path: converts 2 values per call (part of the condition and the call are not visible)
1433 else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
1435 SrcScalarPerVector % 2 == 0)
1436 {
1437 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1438 // DstData)
1440
1441 constexpr index_t pack_size = 2;
1442
1443 using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
1444 using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
1445 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
1447 dst_tmp_vector.template AsType<dst_v_t>()(i),
1448 src_tmp_vector.template AsType<src_v_t>()[i]);
1449 });
1450
1451 // copy data from dst_tmp_vector into dst_buf
1453 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1454 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1455
1456 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1457 });
1458 }
// Generic path: scalar type_convert per element.
// NOTE(review): when PackedSize == 2 this path converts and stores only
// SrcScalarPerVector / PackedSize values per access -- confirm packed sources always take the
// packed path above.
1459 else
1460 {
1461 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1462 // DstData)
1463 vector_type_maker_t<DstData, SrcScalarPerVector / PackedSize> dst_tmp_vector;
1464
1465 // TODO: if SrcData and DstData are vector types, the scalar type_convert may not compile
1466 static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
1467 dst_tmp_vector.template AsType<DstData>()(i) =
1468 type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
1469 });
1470
1471 // copy data from dst_tmp_vector into dst_buf
1472 static_for<0, SrcScalarPerVector / PackedSize, 1>{}([&](auto i) {
1473 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1474 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1475
1476 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1477 });
1478 }
1479 });
1480 }
1481
1482 // Fuse scale: Run() variant that also takes a DstData scale factor
// Traversal, source addressing and destination addressing are the same as the unscaled
// Run(). In the packed path, scale is broadcast into a 2-element vector and passed as the
// third argument of the 8-wide conversion call.
// NOTE(review): the f8 path and the generic path below never reference `scale`, so for those
// SrcData/DstData combinations the result is NOT scaled -- confirm this is intended, or apply
// scale there.
1483 template <typename SrcRefToOriginDisplacement,
1484 typename DstOriginIdx,
1485 typename SrcBuffer,
1486 typename DstBuffer>
1487 __device__ void Run(const SrcDesc&,
1488 const SrcRefToOriginDisplacement&,
1489 const SrcBuffer& src_buf,
1490 const DstData& scale,
1491 const DstDesc&,
1492 const DstOriginIdx&,
1493 DstBuffer& dst_buf) const
1494 {
1495 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1496 "wrong! SrcDesc and DstDesc need to known at compile-time");
1497
1498 static_assert(
1501 "wrong! SrcBuffer or DstBuffer data type is wrong");
1502
1503 static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");
1504
1507 "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known "
1508 "at compile-time");
1509
1510 // SrcDesc and DstDesc are known at compile-time
1511 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
1512 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
1513
1514 // SrcRefToOriginDisplacement and DstOriginIdx are known at compile-time
1515 constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
1516 constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
1517
1518 // scalar per access of each dim (1 except on SrcVectorDim, whose value is not visible here)
1519 constexpr auto src_scalar_per_access = generate_sequence_v2(
1520 [&](auto i) constexpr {
1521 if constexpr(i == SrcVectorDim)
1522 {
1524 }
1525 else
1526 {
1527 return Number<1>{};
1528 }
1529 },
1530 Number<nDim>{});
1531
1532 // scalar step (if stepping on SrcVectorDim) of each dim
1533 constexpr auto src_scalar_step_in_vector = generate_sequence_v2(
1534 [&](auto i) constexpr {
1535 if constexpr(i == SrcVectorDim)
1536 {
1537 return Number<1>{};
1538 }
1539 else
1540 {
1541 return Number<0>{};
1542 }
1543 },
1544 Number<nDim>{});
1545
// number of vector accesses along each dim, then permuted into DimAccessOrder
1546 constexpr auto access_lengths = SliceLengths{} / src_scalar_per_access;
1547
1548 constexpr auto dim_access_order = DimAccessOrder{};
1549
1550 constexpr auto ordered_access_lengths =
1551 container_reorder_given_new2old(access_lengths, dim_access_order);
1552
1553 static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
1554#if 0
1555 // TODO: unable to compile
1556 // position in slice window
1557 constexpr auto data_to_origin_disp_idx =
1558 container_reorder_given_old2new(ordered_access_idx, dim_access_order) *
1559 src_scalar_per_access;
1560#else
1561 // position in slice window
1562 constexpr auto data_to_origin_disp_idx =
1563 ordered_access_idx.ReorderGivenOld2New(dim_access_order) * src_scalar_per_access;
1564#endif
1565 // src coordinate: copy of src_ref_coord_ moved by a compile-time displacement
1566 constexpr auto src_ref_to_data_disp_idx =
1567 src_ref_to_origin_disp_idx + data_to_origin_disp_idx;
1568
1569 constexpr auto src_ref_to_data_disp_coord_step =
1570 make_tensor_coordinate_step(src_desc, src_ref_to_data_disp_idx);
1571
1572 auto src_data_coord = src_ref_coord_;
1573
1574 move_tensor_coordinate(src_desc, src_data_coord, src_ref_to_data_disp_coord_step);
1575
1576 vector_type_maker_t<SrcData, SrcScalarPerVector / PackedSize> src_tmp_vector;
1577
1578 using src_vector_t = typename decltype(src_tmp_vector)::type;
1579
// is_src_valid: validity flag for src_data_coord; its declaration line is not visible here
// (presumably coordinate_has_valid_offset_assuming_visible_index_is_valid -- confirm).
1581 src_desc, src_data_coord);
1582
1583 // copy data from src_buf into src_tmp_vector
// DynamicBuffer: one vector Get() at the coordinate's offset divided by PackedSize, with
// is_src_valid forwarded to Get().
1584 if constexpr(SrcBuffer::IsDynamicBuffer())
1585 {
1586 src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
1587 src_buf.template Get<src_vector_t>(src_data_coord.GetOffset() / PackedSize,
1588 is_src_valid);
1589 }
// StaticBuffer: scalar reads at compile-time offsets (the loop header is not visible here).
// NOTE(review): these offsets do not include src_ref_coord_ -- confirm this is intended.
1590 else if constexpr(SrcBuffer::IsStaticBuffer())
1591 {
1593 constexpr index_t src_offset = src_desc.CalculateOffset(
1594 src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
1595 i * src_scalar_step_in_vector);
1596
1597 src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
1598 });
1599 }
1600
// Packed conversion path with scale: its `if constexpr` condition, the dst_tmp_vector
// declaration, the first line of the 8-wide conversion call and the store-loop header are not
// visible here.
1602 {
1603 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1604 // DstData), applying scale
1606 vector_type<DstData, 2> scale_vector;
1607 scale_vector.template AsType<DstData>()(Number<0>{}) = scale;
1608 scale_vector.template AsType<DstData>()(Number<1>{}) = scale;
1609
1610 constexpr index_t pack_size = 8;
1611
1612 static_assert(SrcScalarPerVector % pack_size == 0, "");
1613
1614 using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
1615 using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
1616 using scale_v_t = typename vector_type_maker_t<DstData, 2>::type;
1617
1618 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
1620 dst_tmp_vector.template AsType<dst_v_t>()(i),
1621 src_tmp_vector.template AsType<src_v_t>()[i],
1622 scale_vector.template AsType<scale_v_t>()[Number<0>{}]);
1623 });
1624
1625 // copy data from dst_tmp_vector into dst_buf
1627 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1628 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1629
1630 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1631 });
1632 }
// f8 path: converts 2 values per call; `scale` is not applied here (see NOTE above)
1633 else if constexpr(is_same<remove_cvref_t<SrcData>, f8_t>::value &&
1635 SrcScalarPerVector % 2 == 0)
1636 {
1637 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1638 // DstData)
1640
1641 constexpr index_t pack_size = 2;
1642
1643 using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
1644 using src_v_t = typename vector_type_maker_t<SrcData, pack_size>::type;
1645 static_for<0, SrcScalarPerVector / pack_size, 1>{}([&](auto i) {
1647 dst_tmp_vector.template AsType<dst_v_t>()(i),
1648 src_tmp_vector.template AsType<src_v_t>()[i]);
1649 });
1650
1651 // copy data from dst_tmp_vector into dst_buf
1653 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1654 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1655
1656 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1657 });
1658 }
// Generic path: scalar type_convert per element; `scale` is not applied here (see NOTE
// above). The dst_tmp_vector declaration and both loop headers are not visible here.
1659 else
1660 {
1661 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1662 // DstData)
1664
1665 // TODO: if SrcData and DstData are vector types, the scalar type_convert may not compile
1667 dst_tmp_vector.template AsType<DstData>()(i) =
1668 type_convert<DstData>(src_tmp_vector.template AsType<SrcData>()[i]);
1669 });
1670
1671 // copy data from dst_tmp_vector into dst_buf
1673 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1674 dst_origin_idx + data_to_origin_disp_idx + i * src_scalar_step_in_vector);
1675
1676 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1677 });
1678 }
1679 });
1680 }
1681
1682 template <typename SrcSliceMoveStepIdx>
1683 __device__ void MoveSrcSliceWindow(const SrcDesc&,
1684 const SrcSliceMoveStepIdx& src_slice_move_step_idx)
1685 {
1686 constexpr auto src_desc = SrcDesc{};
1687
1688 const auto src_slice_move_step_iter =
1689 make_tensor_coordinate_step(src_desc, to_multi_index(src_slice_move_step_idx));
1690
1691 move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
1692 }
1693 __device__ void SetSrcCoord(const Index& src_ref_idx)
1694 {
1695 src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
1696 }
1697
1698 private:
// Run-time reference coordinate: set by the constructor and SetSrcCoord(), advanced by
// MoveSrcSliceWindow(), and used as the starting point of DynamicBuffer reads in Run().
1699 SrcCoord src_ref_coord_;
1700};
1701
// Threadwise copy of a compile-time slice between two StaticBuffers: both descriptors and both
// slice origins are compile-time, and the copy applies element_op_ per scalar (or a packed
// conversion path for packed SrcData). DstScalarPerVector must divide the slice length along
// DstVectorDim (checked in the constructor).
1708template <typename SrcData,
1709 typename DstData,
1710 typename SrcDesc,
1711 typename DstDesc,
1712 typename ElementwiseOperation,
1713 typename SliceLengths,
1714 typename DimAccessOrder,
1715 index_t DstVectorDim,
1716 index_t DstScalarPerVector,
1717 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1718 bool>::type = false>
1720{
1721 static constexpr index_t nDim = SliceLengths::Size();
1722
// NOTE(review): the Index alias declaration is not visible here (presumably MultiIndex<nDim>).
1724
// 2 when the (not visible) `if constexpr` condition holds -- presumably packed SrcData --
// otherwise 1.
1725 static constexpr index_t PackedSize = []() {
1727 return 2;
1728 else
1729 return 1;
1730 }();
1731
// Constructor (its first signature line is not visible here): stores element_op and checks
// the compile-time preconditions.
1733 const ElementwiseOperation& element_op)
1734 : element_op_{element_op}
1735 {
1736 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1737 "wrong! Desc need to known at compile-time");
1738
1739 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
1740 "wrong! Not divisible");
1741 }
1742
// Copies the slice from the StaticBuffer src_buf into the StaticBuffer dst_buf. Accesses are
// enumerated by SpaceFillingCurve over SliceLengths in DimAccessOrder, DstScalarPerVector
// scalars per access; every offset is a compile-time constant. Descriptor and origin
// arguments are used only for their types.
1743 template <typename SrcSliceOriginIdx,
1744 typename DstSliceOriginIdx,
1745 typename SrcBuffer,
1746 typename DstBuffer>
1747 __device__ void Run(const SrcDesc&,
1748 const SrcSliceOriginIdx&,
1749 const SrcBuffer& src_buf,
1750 const DstDesc&,
1751 const DstSliceOriginIdx&,
1752 DstBuffer& dst_buf) const
1753 {
1754 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1755 "wrong! Desc need to known at compile-time");
1756
1759 "wrong! SliceOrigin need to known at compile-time");
1760
1761 static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1762 "wrong! Buffer need to be StaticBuffer");
1763
1764 // SrcDesc and src_slice_origin_idx are known at compile-time
1765 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
1766 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
1767 constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
1768 constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
1769
// NOTE(review): the generator arguments of the next two definitions are not visible here.
1770 // scalar per access on each dim
1771 constexpr auto dst_scalar_per_access = generate_sequence(
1773
1774 constexpr auto dst_scalar_step_in_vector =
1776
1777 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
1778 DimAccessOrder,
1779 remove_cv_t<decltype(dst_scalar_per_access)>>;
1780
1781 static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
1782 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1783
1784 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
1785
// Packed path: its `if constexpr` condition, the dst_tmp_vector declaration, the first line of
// the 8-wide conversion call and the store-loop header are not visible here.
// NOTE(review): element_op_ is not referenced in any visible line of this path -- confirm the
// elementwise operation is meant to be skipped for packed data.
1787 {
1788 static_for<0, num_access, 1>{}([&](auto idx_1d) {
1789 typename vector_type_maker<SrcData, DstScalarPerVector / PackedSize>::type
1790 src_tmp_vector;
1791
1792 constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
1793
1794 // copy data from src_buf into src_tmp_vector
1795 static_for<0, DstScalarPerVector / PackedSize, 1>{}([&](auto i) {
1796 constexpr index_t src_offset = src_desc.CalculateOffset(
1797 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1798
1799 src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
1800 });
1801
1802 // copy data from src_tmp_vector to dst_tmp_vector (cast data from SrcData to
1803 // DstData), pack_size values per conversion call
1805
1806 constexpr index_t pack_size = 8;
1807
1808 static_assert(DstScalarPerVector % pack_size == 0, "");
1809
1810 using src_v_t = typename vector_type_maker_t<SrcData, pack_size / PackedSize>::type;
1811 using dst_v_t = typename vector_type_maker_t<DstData, pack_size>::type;
1812
1813 static_for<0, DstScalarPerVector / pack_size, 1>{}([&](auto i) {
1815 dst_tmp_vector.template AsType<dst_v_t>()(i),
1816 src_tmp_vector.template AsType<src_v_t>()[i]);
1817 });
1818
1819 // copy data from dst_tmp_vector into dst_buf
1821 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1822 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1823
1824 dst_buf(Number<dst_offset>{}) = dst_tmp_vector.template AsType<DstData>()[i];
1825 });
1826 });
1827 }
// Generic path: element_op_ maps each SrcData scalar to a DstData value stored directly into
// dst_buf (the inner loop header is not visible here).
1828 else
1829 {
1830 static_for<0, num_access, 1>{}([&](auto idx_1d) {
1831 constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
1832
1833 // copy data from src_buf into dst_buf, one scalar at a time
1835 constexpr index_t src_offset = src_desc.CalculateOffset(
1836 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1837
1838 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1839 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1840
1841 DstData v;
1842
1843 // apply element-wise operation (writes a DstData result into v)
1844 element_op_(v, src_buf[Number<src_offset>{}]);
1845
1846 // store the result; element_op_ has already produced DstData
1847 dst_buf(Number<dst_offset>{}) = v;
1848 });
1849 });
1850 }
1851 }
1852
// Elementwise operation set by the constructor and applied per scalar on the generic Run() path.
1853 ElementwiseOperation element_op_;
1854};
1855
1856// Specialized for gfx11
1857// A single Wave32 consists of two rows of 16 lanes
1858// Data exchange is allowed between these two rows
1859// Each lane's dst buffer is filled from two src buffers:
1860// SrcA: the buffer held by this lane in this row
1861// SrcB: the buffer held by a specific lane of the other row (chosen by the
// LowEightRowlaneIdx / HighEightRowLaneIdx lane selectors passed to permlanex16 in Run())
// IntraRowSwizzlePerm additionally enables a fixed intra-row permute before the exchange.
1862template <typename SrcData,
1863 typename DstData,
1864 typename SrcDesc,
1865 typename DstDesc,
1866 typename ElementwiseOperation,
1867 typename SliceLengths,
1868 typename DimAccessOrder,
1869 index_t DstVectorDim,
1870 index_t DstScalarPerVector,
1871 uint32_t LowEightRowlaneIdx,
1872 uint32_t HighEightRowLaneIdx,
1873 bool IntraRowSwizzlePerm,
1874 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1875 bool>::type = false>
1877{
1878 static constexpr index_t nDim = SliceLengths::Size();
1879
// NOTE(review): the Index alias declaration is not visible here (presumably MultiIndex<nDim>).
1881
// Constructor (its signature line is not visible here): checks compile-time preconditions;
// src_idx is accepted but ignored.
1883 {
1884 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1885 "wrong! Desc need to known at compile-time");
1886
1887 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
1888 "wrong! Not divisible");
1889 ignore = src_idx;
1890 }
1891
// Copies the slice from src_buf to dst_buf (both StaticBuffers), exchanging each scalar with
// the other 16-lane row of the Wave32 through __builtin_amdgcn_permlanex16 so that every lane
// stores both its own value and the other row's value.
1892 template <typename SrcSliceOriginIdx,
1893 typename DstSliceOriginIdx,
1894 typename SrcBuffer,
1895 typename DstBuffer>
1896 __device__ void Run(const SrcDesc&,
1897 const SrcSliceOriginIdx&,
1898 const SrcBuffer& src_buf,
1899 const DstDesc&,
1900 const DstSliceOriginIdx&,
1901 DstBuffer& dst_buf) const
1902 {
// default-constructed elementwise op (no element_op_ member is visible in this struct)
1903 ElementwiseOperation element_op_{};
1904 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
1905 "wrong! Desc need to known at compile-time");
1906
1909 "wrong! SliceOrigin need to known at compile-time");
1910
1911 static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
1912 "wrong! Buffer need to be StaticBuffer");
1913
1914 // SrcDesc and src_slice_origin_idx are known at compile-time
1915 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
1916 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
1917 constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
1918 constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
1919
// NOTE(review): the generator arguments of the next two definitions are not visible here.
1920 // scalar per access on each dim
1921 constexpr auto dst_scalar_per_access = generate_sequence(
1923
1924 constexpr auto dst_scalar_step_in_vector =
1926
1927 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
1928 DimAccessOrder,
1929 remove_cv_t<decltype(dst_scalar_per_access)>>;
1930
1931 static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
1932 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
1933
1934 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
1935
1936 static_for<0, num_access, 1>{}([&](auto idx_1d) {
1937 constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
1938
1939 // copy data from src_buf into dst_buf, one scalar at a time (loop header not visible)
1941 // src_desc error, non constexpr, caused by merge transform
1942 constexpr index_t src_offset = src_desc.CalculateOffset(
1943 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1944
1945 constexpr index_t dst_offset = dst_desc.CalculateOffset(
1946 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
1947
1948 SrcData v_this_row, v_theother_row;
1949 // int type temp value due to intrinsic requirement
1950 int temp = 0;
1951
1952 // apply element-wise operation
1953 element_op_(v_this_row, src_buf[Number<src_offset>{}]);
1954
1955 // apply intra-row permute.
// Reading the selector nibbles (low lane first), lanes 0..15 take lanes
// 0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 -- an interleave of the two 8-lane halves
// (verify against the ISA's V_PERMLANE16_B32 description).
1956 if constexpr(IntraRowSwizzlePerm)
1957 {
1958 temp = __builtin_amdgcn_permlane16(
1959 temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
1960 v_this_row = type_convert_sp<SrcData>(temp);
1961 }
1962
1963 // apply inter-row permute: fetch the value of the other row's lane selected by
// LowEightRowlaneIdx / HighEightRowLaneIdx
1964 temp = __builtin_amdgcn_permlanex16(temp,
1965 type_convert_sp<int>(v_this_row),
1966 LowEightRowlaneIdx,
1967 HighEightRowLaneIdx,
1968 1,
1969 0);
1970 v_theother_row = type_convert_sp<SrcData>(temp);
1971
// Lanes 0-15 of each 32-lane group store their own value at dst_offset and the other row's
// value at a second slot; lanes 16-31 store them the other way round.
// NOTE(review): the second slot's index expression is not visible here; the hard-coded
// 32 / 16 assume Wave32.
1972 if(get_thread_local_1d_id() % 32 < 16)
1973 {
1974 // apply type convert
1975 dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
1977 type_convert_sp<DstData>(v_theother_row);
1978 }
1979 else
1980 {
1981 // apply type convert
1983 type_convert_sp<DstData>(v_this_row);
1984 dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_theother_row);
1985 }
1986 });
1987 });
1988 }
1989};
1990
1991// Specialized for gfx12
// Threadwise StaticBuffer-to-StaticBuffer copy with an optional fixed intra-row lane permute
// (IntraRowSwizzlePerm); unlike the inter-row variant, no data is exchanged between rows.
1992template <typename SrcData,
1993 typename DstData,
1994 typename SrcDesc,
1995 typename DstDesc,
1996 typename ElementwiseOperation,
1997 typename SliceLengths,
1998 typename DimAccessOrder,
1999 index_t DstVectorDim,
2000 index_t DstScalarPerVector,
2001 bool IntraRowSwizzlePerm,
2002 typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
2003 bool>::type = false>
2005{
2006 static constexpr index_t nDim = SliceLengths::Size();
2007
// NOTE(review): the Index alias declaration is not visible here (presumably MultiIndex<nDim>).
2009
// Constructor (its signature line is not visible here): checks compile-time preconditions;
// src_idx is accepted but ignored.
2011 {
2012 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
2013 "wrong! Desc need to known at compile-time");
2014
2015 static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
2016 "wrong! Not divisible");
2017 ignore = src_idx;
2018 }
2019
// Copies the slice from src_buf to dst_buf (both StaticBuffers) one scalar at a time, applying
// element_op_ and, when IntraRowSwizzlePerm is set, a fixed permlane16 permute within the row.
2020 template <typename SrcSliceOriginIdx,
2021 typename DstSliceOriginIdx,
2022 typename SrcBuffer,
2023 typename DstBuffer>
2024 __device__ void Run(const SrcDesc&,
2025 const SrcSliceOriginIdx&,
2026 const SrcBuffer& src_buf,
2027 const DstDesc&,
2028 const DstSliceOriginIdx&,
2029 DstBuffer& dst_buf) const
2030 {
2031 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
2032 "wrong! Desc need to known at compile-time");
2033
2036 "wrong! SliceOrigin need to known at compile-time");
2037
2038 static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
2039 "wrong! Buffer need to be StaticBuffer");
2040
2041 // SrcDesc and src_slice_origin_idx are known at compile-time
2042 constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
2043 constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
2044 constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
2045 constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
2046
// NOTE(review): the generator arguments of the next two definitions are not visible here.
2047 // scalar per access on each dim
2048 constexpr auto dst_scalar_per_access = generate_sequence(
2050
2051 constexpr auto dst_scalar_step_in_vector =
2053
2054 using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
2055 DimAccessOrder,
2056 remove_cv_t<decltype(dst_scalar_per_access)>>;
2057
2058 static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
2059 "wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
2060
2061 constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
2062
2063 static_for<0, num_access, 1>{}([&](auto idx_1d) {
2064 constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
2065
2066 // copy data from src_buf into dst_buf, one scalar at a time (loop header not visible)
2068 // src_desc error, non constexpr, caused by merge transform
2069 constexpr index_t src_offset = src_desc.CalculateOffset(
2070 src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
2071
2072 constexpr index_t dst_offset = dst_desc.CalculateOffset(
2073 dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
2074
2075 SrcData v_this_row;
2076 // int type temp value due to intrinsic requirement
2077 int temp = 0;
2078
2079 // apply element-wise operation
2080 element_op_(v_this_row, src_buf[Number<src_offset>{}]);
2081
2082 // apply intra-row permute.
// Reading the selector nibbles (low lane first), lanes 0..15 take lanes
// 0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 -- an interleave of the two 8-lane halves
// (verify against the ISA's V_PERMLANE16_B32 description).
2083 if constexpr(IntraRowSwizzlePerm)
2084 {
2085 temp = __builtin_amdgcn_permlane16(
2086 temp, type_convert_sp<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
2087 v_this_row = type_convert_sp<SrcData>(temp);
2088 }
2089
2090 // apply type convert
2091 dst_buf(Number<dst_offset>{}) = type_convert_sp<DstData>(v_this_row);
2092 });
2093 });
2094 }
// Elementwise operation applied per scalar in Run(); always default-constructed, since no
// visible constructor takes one.
2095 ElementwiseOperation element_op_{};
2096};
2097
2098} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
f8_fnuz_t f8_t
Definition amd_ck_fp8.hpp:1762
__host__ __device__ constexpr Y type_convert_sp(X x)
Definition utility/type_convert.hpp:205
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto make_tensor_coordinate_step(const TensorDesc &, const VisibleIndex &idx_diff_visible, UpdateLowerIndexHack)
Definition tensor_description/tensor_descriptor.hpp:444
__host__ __device__ constexpr void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const TensorCoordStep &coord_step)
Definition tensor_description/tensor_descriptor.hpp:508
InMemoryDataOperationEnum
Definition ck.hpp:277
@ Set
Definition ck.hpp:278
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr bool coordinate_has_valid_offset_assuming_visible_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition tensor_description/tensor_descriptor.hpp:560
_Float16 half_t
Definition data_type.hpp:31
integral_constant< index_t, N > Number
Definition number.hpp:12
@ Lds
Definition amd_address_space.hpp:18
@ Global
Definition amd_address_space.hpp:17
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence< Is... >)
Definition utility/container_helper.hpp:380
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto generate_sequence_v2(F &&f, Number< N >)
Definition sequence_helper.hpp:25
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__host__ __device__ constexpr auto container_reorder_given_old2new(const Array< TData, NSize > &old_array, Sequence< IRs... > old2new)
Definition utility/container_helper.hpp:54
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
constexpr bool is_same_v
Definition type.hpp:283
__host__ __device__ constexpr auto generate_tuple(F &&f, Number< N >)
Definition tuple_helper.hpp:21
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__host__ __device__ constexpr auto generate_sequence(F, Number< N >)
Definition sequence_helper.hpp:18
typename remove_cv< T >::type remove_cv_t
Definition type.hpp:295
__host__ __device__ constexpr auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition tensor_description/tensor_descriptor.hpp:407
__host__ __device__ constexpr auto container_reorder_given_new2old(const Array< TData, NSize > &old_array, Sequence< IRs... >)
Definition utility/container_helper.hpp:43
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
typename vector_type_maker< T, N >::type vector_type_maker_t
Definition dtype_vector.hpp:54
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
unsigned int uint32_t
Definition stdint.h:126
__host__ static __device__ constexpr T QuietNaN()
Definition numeric_limits.hpp:313
Definition tensor_space_filling_curve.hpp:20
static __device__ __host__ constexpr auto GetStepBetween(Number< AccessIdx1dBegin >, Number< AccessIdx1dEnd >)
Definition tensor_space_filling_curve.hpp:52
__host__ static __device__ constexpr index_t GetNumOfAccess()
Definition tensor_space_filling_curve.hpp:41
static __device__ __host__ constexpr Index GetIndex(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:81
static constexpr index_t ScalarPerVector
Definition tensor_space_filling_curve.hpp:25
static __device__ __host__ constexpr auto GetForwardStep(Number< AccessIdx1d >)
Definition tensor_space_filling_curve.hpp:66
MultiIndex< nDim > Index
Definition tensor_space_filling_curve.hpp:23
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:1896
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index &src_idx)
Definition threadwise_tensor_slice_transfer.hpp:1882
static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer.hpp:2006
MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer.hpp:2008
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow(const Index &src_idx)
Definition threadwise_tensor_slice_transfer.hpp:2010
ElementwiseOperation element_op_
Definition threadwise_tensor_slice_transfer.hpp:2095
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:2024
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:1747
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer.hpp:1732
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer.hpp:149
static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer.hpp:40
MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer.hpp:42
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition threadwise_tensor_slice_transfer.hpp:44
__device__ constexpr ThreadwiseTensorSliceTransfer_v1r3(const DstDesc &dst_desc, const Index &dst_slice_origin_idx, const ElementwiseOperation &element_op)
Definition threadwise_tensor_slice_transfer.hpp:48
decltype(make_tensor_coordinate_step(DstDesc{}, Index{})) DstCoordStep
Definition threadwise_tensor_slice_transfer.hpp:46
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:173
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer.hpp:60
__device__ void Run(const SrcDesc &, const SrcSliceOriginIdx &, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:66
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:493
static constexpr index_t PackedSize
Definition threadwise_tensor_slice_transfer.hpp:453
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition threadwise_tensor_slice_transfer.hpp:634
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer.hpp:478
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition threadwise_tensor_slice_transfer.hpp:449
MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer.hpp:447
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition threadwise_tensor_slice_transfer.hpp:451
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:617
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer.hpp:593
static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer.hpp:445
__device__ constexpr ThreadwiseTensorSliceTransfer_v2_gather(const SrcDesc &src_desc, const Index &src_slice_origin_idx, const StaticallyIndexedArray< index_t, scale_gather_num > &scale_gather_offsets)
Definition threadwise_tensor_slice_transfer.hpp:460
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:276
MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer.hpp:241
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:389
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition threadwise_tensor_slice_transfer.hpp:406
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer.hpp:365
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer.hpp:270
static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer.hpp:239
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition threadwise_tensor_slice_transfer.hpp:245
__device__ constexpr ThreadwiseTensorSliceTransfer_v2(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer.hpp:254
static constexpr index_t PackedSize
Definition threadwise_tensor_slice_transfer.hpp:247
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition threadwise_tensor_slice_transfer.hpp:243
__device__ constexpr ThreadwiseTensorSliceTransfer_v3(const SrcDesc &src_desc, const Index &src_slice_origin, const DstDesc &dst_desc, const Index &dst_slice_origin)
Definition threadwise_tensor_slice_transfer.hpp:691
static __device__ constexpr auto GetDstCoordinateResetStep()
Definition threadwise_tensor_slice_transfer.hpp:1115
static __device__ constexpr auto GetSrcCoordinateResetStep()
Definition threadwise_tensor_slice_transfer.hpp:1055
decltype(make_tensor_coordinate(DstDesc{}, Index{})) DstCoord
Definition threadwise_tensor_slice_transfer.hpp:686
decltype(make_tensor_coordinate_step(SrcDesc{}, Index{})) SrcCoordStep
Definition threadwise_tensor_slice_transfer.hpp:688
MultiIndex< nDim > Index
Definition threadwise_tensor_slice_transfer.hpp:683
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx, const SrcMoveSliceWindowStepHack &src_move_slice_window_step_hack)
Definition threadwise_tensor_slice_transfer.hpp:1193
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, const SrcStepHacks &src_step_hacks)
Definition threadwise_tensor_slice_transfer.hpp:716
decltype(make_tensor_coordinate_step(DstDesc{}, Index{})) DstCoordStep
Definition threadwise_tensor_slice_transfer.hpp:689
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &dst_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:1209
__device__ void SetDstSliceOrigin(const DstDesc &dst_desc, const Index &dst_slice_origin_idx)
Definition threadwise_tensor_slice_transfer.hpp:709
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition threadwise_tensor_slice_transfer.hpp:1042
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_slice_origin_idx)
Definition threadwise_tensor_slice_transfer.hpp:704
decltype(make_tensor_coordinate(SrcDesc{}, Index{})) SrcCoord
Definition threadwise_tensor_slice_transfer.hpp:685
static constexpr index_t nDim
Definition threadwise_tensor_slice_transfer.hpp:682
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf)
Definition threadwise_tensor_slice_transfer.hpp:1028
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:1176
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, const DstStepHacks &dst_step_hacks)
Definition threadwise_tensor_slice_transfer.hpp:871
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:1293
__device__ constexpr ThreadwiseTensorSliceTransfer_v4(const Index &src_ref_idx)
Definition threadwise_tensor_slice_transfer.hpp:1276
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstData &scale, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition threadwise_tensor_slice_transfer.hpp:1487
__device__ void SetSrcCoord(const Index &src_ref_idx)
Definition threadwise_tensor_slice_transfer.hpp:1693
__device__ void MoveSrcSliceWindow(const SrcDesc &, const SrcSliceMoveStepIdx &src_slice_move_step_idx)
Definition threadwise_tensor_slice_transfer.hpp:1683
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition threadwise_tensor_slice_transfer_util.hpp:20
Definition threadwise_tensor_slice_transfer_util.hpp:29
Definition data_type.hpp:42
Definition is_known_at_compile_time.hpp:14
Definition type.hpp:177
Definition data_type.hpp:187
Definition functional2.hpp:33
Definition functional3.hpp:97
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:269
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:307
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:159
typename sequence_gen< NSize, F >::type type
Definition utility/sequence.hpp:295
Definition dtype_vector.hpp:30
vector_type< T, N > type
Definition dtype_vector.hpp:31
Definition dtype_vector.hpp:10