23template <
typename ThreadGroup,
24 typename SrcElementwiseOperation,
25 typename DstElementwiseOperation,
27 typename BlockSliceLengths,
28 typename ThreadClusterLengths,
29 typename ThreadClusterArrangeOrder,
34 typename SrcDimAccessOrder,
35 typename DstDimAccessOrder,
40 index_t SrcScalarStrideInVector,
41 index_t DstScalarStrideInVector,
42 bool ThreadTransferSrcResetCoordinateAfterRun,
43 bool ThreadTransferDstResetCoordinateAfterRun,
54 const SrcDesc& src_desc,
55 const Index& src_block_slice_origin,
56 const SrcElementwiseOperation& src_element_op,
57 const DstDesc& dst_desc,
58 const Index& dst_block_slice_origin,
59 const DstElementwiseOperation& dst_element_op)
60 : threadwise_transfer_(src_desc,
70 nDim == ThreadClusterLengths::Size() &&
71 nDim == ThreadClusterArrangeOrder::Size() &&
72 nDim == SrcDimAccessOrder::Size() &&
nDim == DstDimAccessOrder::Size(),
73 "wrong! nDim not consistent");
77 "wrong! threads should be mapped to cover entire slicing window");
79 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
80 "wrong! ThreadGroup::GetNumOfThread() too small");
82 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
83 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
85 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
86 make_multi_index(ThreadGroup::GetThreadId()));
88 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
90 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
91 src_block_slice_origin + thread_data_idx_begin);
92 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
93 dst_block_slice_origin + thread_data_idx_begin);
99 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
100 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
102 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
107 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
108 src_block_slice_origin + thread_data_idx_begin);
112 template <
typename SeqIdx, index_t ThreadScratchId = 0>
118 template <
typename SrcBuffer, index_t ThreadScratchId = 0>
119 __device__
void RunRead(
const SrcDesc& src_desc,
120 const SrcBuffer& src_buf,
123 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
124 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
126 threadwise_transfer_.
RunRead(src_desc, src_buf, thread_scratch_id);
130 template <
typename DstBuffer, index_t ThreadScratchId = 0>
135 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
136 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
138 threadwise_transfer_.
RunWrite(dst_desc, dst_buf, thread_scratch_id);
142 template <
typename SrcBuffer,
typename DstBuffer, index_t ThreadScratchId>
143 __device__
void Run(
const SrcDesc& src_desc,
144 const SrcBuffer& src_buf,
145 const DstDesc& dst_desc,
149 RunRead(src_desc, src_buf, thread_scratch_id);
150 RunWrite(dst_desc, dst_buf, thread_scratch_id);
155 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
156 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
158 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
164 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
165 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
167 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
172 static constexpr auto thread_cluster_desc_ =
175 using ThreadwiseTransfer =
177 SrcElementwiseOperation,
178 DstElementwiseOperation,
190 SrcScalarStrideInVector,
191 DstScalarStrideInVector,
192 ThreadTransferSrcResetCoordinateAfterRun,
193 ThreadTransferDstResetCoordinateAfterRun,
196 ThreadwiseTransfer threadwise_transfer_;
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:143
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v4r1.hpp:51
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v4r1.hpp:47
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r1.hpp:49
__device__ void SetSrcSliceOrigin(const SrcDesc &src_desc, const Index &src_block_slice_origin)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:97
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1.hpp:119
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:153
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:162
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1(const SrcDesc &src_desc, const Index &src_block_slice_origin, const SrcElementwiseOperation &src_element_op, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const DstElementwiseOperation &dst_element_op)
Definition thread_group_tensor_slice_transfer_v4r1.hpp:53
__device__ constexpr auto GetSrcThreadScratchIdx()
Definition thread_group_tensor_slice_transfer_v4r1.hpp:113
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1.hpp:131
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:118
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1.hpp:521