thread_group_tensor_slice_transfer_v6r1.hpp Source File

thread_group_tensor_slice_transfer_v6r1.hpp Source File#

Composable Kernel: thread_group_tensor_slice_transfer_v6r1.hpp Source File
thread_group_tensor_slice_transfer_v6r1.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
11
12namespace ck {
13
14// this version does the following things to avoid the scratch memory issue
15// 1. Use StaticallyIndexedArray instead of C array for thread buffer
16// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
17// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
18template <typename ThreadGroup,
19 typename ElementwiseOperation,
21 typename SliceLengths,
22 typename ThreadClusterLengths,
23 typename ThreadClusterArrangeOrder,
24 typename SrcData,
25 typename DstData,
26 typename SrcDesc,
27 typename DstDesc,
28 typename DimAccessOrder,
29 index_t VectorDim,
30 index_t ScalarPerVector,
31 bool ThreadTransferSrcResetCoordinateAfterRun,
32 bool ThreadTransferDstResetCoordinateAfterRun>
34{
36
37 static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};
38
40
41 __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc& src_desc,
42 const Index& src_block_slice_origin,
43 const DstDesc& dst_desc,
44 const Index& dst_block_slice_origin,
45 const ElementwiseOperation& element_op)
46 : threadwise_transfer_(src_desc,
48 dst_desc,
50 element_op)
51
52 {
55 nDim == ThreadClusterLengths::Size() &&
56 nDim == ThreadClusterArrangeOrder::Size() &&
57 nDim == DimAccessOrder::Size(),
58 "wrong! nDim not consistent");
59
60 static_assert(
61 is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
62 "wrong! threads should be mapped to cover entire slicing window");
63
64 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
65 "wrong! ThreadGroup::GetNumOfThread() too small");
66
67 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
68 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
69 {
70 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
71 make_multi_index(ThreadGroup::GetThreadId()));
72
73 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
74
75 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
76 src_block_slice_origin + thread_data_idx_begin);
77 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
78 dst_block_slice_origin + thread_data_idx_begin);
79 }
80 }
81
82 template <typename SrcBuffer, typename DstBuffer>
83 __device__ void Run(const SrcDesc& src_desc,
84 const SrcBuffer& src_buf,
85 const DstDesc& dst_desc,
86 DstBuffer& dst_buf)
87 {
88 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
89 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
90 {
91 threadwise_transfer_.Run(src_desc, src_buf, dst_desc, dst_buf);
92 }
93 }
94
95 __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
96 {
97 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
98 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
99 {
100 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
101 }
102 }
103
104 __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
105 {
106 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
107 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
108 {
109 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
110 }
111 }
112
113 private:
// Cluster descriptor built from ThreadClusterLengths and ThreadClusterArrangeOrder.
// The constructor calls CalculateBottomIndex() on it to turn the thread id into a
// multi-dimensional cluster index, and every public member compares against its
// GetElementSize() to decide whether the calling thread takes part.
114 static constexpr auto thread_cluster_desc_ =
115 make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});
116
// Per-thread transfer type. It is instantiated with decltype(thread_slice_lengths),
// i.e. SliceLengths / ThreadClusterLengths, together with the data types, descriptors,
// element-wise operation, access order and vector settings of this class.
// NOTE(review): DstInMemOp is not declared in the template parameter list shown in
// this listing (the listing skips a source line there) - confirm against the
// original header.
117 using ThreadwiseTransfer =
118 ThreadwiseTensorSliceTransfer_v6r1<SrcData,
119 DstData,
120 SrcDesc,
121 DstDesc,
122 ElementwiseOperation,
123 decltype(thread_slice_lengths),
124 DimAccessOrder,
125 VectorDim,
126 ScalarPerVector,
127 DstInMemOp,
128 ThreadTransferSrcResetCoordinateAfterRun,
129 ThreadTransferDstResetCoordinateAfterRun>;
130
// The per-thread transfer object that Run(), MoveSrcSliceWindow() and
// MoveDstSliceWindow() forward to; its slice origins are set in the constructor.
131 ThreadwiseTransfer threadwise_transfer_;
132};
133
134} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v6r1.hpp:39
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v6r1.hpp:37
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:95
__device__ constexpr ThreadGroupTensorSliceTransfer_v6r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const ElementwiseOperation &element_op)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:41
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:104
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v6r1.hpp:35
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf)
Definition thread_group_tensor_slice_transfer_v6r1.hpp:83
Definition type.hpp:177