// Template parameter list of the block-level slice transfer.
// NOTE(review): only part of the list appears in this listing (the embedded
// source line numbers skip values), so the parameters below are not complete.
//  - ThreadGroup: queried through GetNumOfThread() and GetThreadId() by the
//    participation guards in this file.
//  - Src/Scale/DstElementwiseOperation: types of the element-op objects the
//    constructor receives; also listed among the ThreadwiseTransfer arguments.
//  - ThreadClusterLengths, ThreadClusterArrangeOrder, SrcDimAccessOrder,
//    DstDimAccessOrder: the constructor statically asserts Size() == nDim.
//  - BlockScaleSliceLengths: divided by ThreadClusterLengths in a static
//    constexpr initializer.
//  - *ScalarStrideInVector, ThreadTransfer*ResetCoordinateAfterRun: listed
//    unchanged among the ThreadwiseTransfer arguments.
27template <
typename ThreadGroup,
28 typename SrcElementwiseOperation,
29 typename ScaleElementwiseOperation,
30 typename DstElementwiseOperation,
32 typename BlockSliceLengths,
33 typename BlockScaleSliceLengths,
34 typename ThreadClusterLengths,
35 typename ThreadClusterArrangeOrder,
42 typename SrcDimAccessOrder,
43 typename DstDimAccessOrder,
49 index_t SrcScalarStrideInVector,
50 index_t ScaleScalarStrideInVector,
51 index_t DstScalarStrideInVector,
52 bool ThreadTransferSrcResetCoordinateAfterRun,
53 bool ThreadTransferDstResetCoordinateAfterRun,
// Right-hand side of a static constexpr definition: the block-level scale
// slice lengths divided by the thread cluster lengths.
// NOTE(review): presumably the initializer of scale_thread_slice_lengths (the
// declaration line is not shown here) — confirm against the full header.
61 BlockScaleSliceLengths{} / ThreadClusterLengths{};
// Constructor (fragment; the opening line, braces and parts of the
// static_asserts are missing from this listing).
// For each of src, scale and dst it takes a tensor descriptor, the origin of
// the block's slice, and an element-wise operation object.
66 const SrcDesc& src_desc,
67 const Index& src_block_slice_origin,
68 const SrcElementwiseOperation& src_element_op,
69 const ScaleDesc& scale_desc,
70 const Index& scale_block_slice_origin,
71 const ScaleElementwiseOperation& scale_element_op,
72 const DstDesc& dst_desc,
73 const Index& dst_block_slice_origin,
74 const DstElementwiseOperation& dst_element_op)
// Member initializer for the per-thread transfer object; its remaining
// arguments are not shown in this listing.
75 : threadwise_transfer_(src_desc,
// Tail of a static_assert: the thread-cluster lengths, the cluster arrange
// order and both dimension access orders must all have nDim entries.
89 nDim == ThreadClusterLengths::Size() &&
90 nDim == ThreadClusterArrangeOrder::Size() &&
91 nDim == SrcDimAccessOrder::Size() &&
nDim == DstDimAccessOrder::Size(),
92 "wrong! nDim not consistent");
// Message of a second static_assert whose condition is not shown here.
98 "wrong! threads should be mapped to cover entire slicing window");
// The thread group must supply at least as many threads as the cluster has
// elements.
100 static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
101 "wrong! ThreadGroup::GetNumOfThread() too small");
// Only threads that map into the cluster set up origins: either the group size
// equals the cluster size (every thread participates) or this thread's id is
// below the cluster size.
103 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
104 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
// Map the flat thread id through the cluster descriptor to this thread's
// multi-dimensional position in the cluster.
106 const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
107 make_multi_index(ThreadGroup::GetThreadId()));
// Offset of this thread's sub-slice inside the block slice.
109 const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
// Each origin is the block-level origin shifted by the per-thread offset.
111 threadwise_transfer_.SetSrcSliceOrigin(src_desc,
112 src_block_slice_origin + thread_data_idx_begin);
// NOTE(review): the scale origin uses the same offset, which is derived from
// thread_slice_lengths rather than scale_thread_slice_lengths — confirm this
// is intended when the scale slice has different lengths.
113 threadwise_transfer_.SetScaleSliceOrigin(
114 scale_desc, scale_block_slice_origin + thread_data_idx_begin);
115 threadwise_transfer_.SetDstSliceOrigin(dst_desc,
116 dst_block_slice_origin + thread_data_idx_begin);
// RunRead (fragment; the trailing thread_scratch_id parameter and the braces
// are missing from this listing).
// Forwards the source read to the per-thread transfer object, passing
// src_desc, src_buf and thread_scratch_id through unchanged.
120 template <
typename SrcBuffer, index_t ThreadScratchId = 0>
121 __device__
void RunRead(
const SrcDesc& src_desc,
122 const SrcBuffer& src_buf,
// Threads outside the thread cluster skip the read: the call runs only when
// the group size equals the cluster size or this thread's id is below it.
125 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
126 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
128 threadwise_transfer_.
RunRead(src_desc, src_buf, thread_scratch_id);
// RunScaleRead (fragment; braces are missing from this listing).
// Forwards the scale read to the per-thread transfer object with the given
// scale descriptor and scale buffer.
133 template <
typename ScaleBuffer>
134 __device__
void RunScaleRead(
const ScaleDesc& scale_desc,
const ScaleBuffer& scale_buf)
// Same participation guard as the other member functions: only threads that
// map into the thread cluster make the call.
136 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
137 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
139 threadwise_transfer_.RunScaleRead(scale_desc, scale_buf);
// RunWrite (fragment; the function signature and braces are missing from this
// listing — only the template header and the guarded call are visible).
// Forwards the destination write to the per-thread transfer object, passing
// dst_desc, dst_buf and thread_scratch_id through unchanged.
143 template <
typename DstBuffer, index_t ThreadScratchId = 0>
// Only threads that map into the thread cluster perform the write.
148 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
149 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
151 threadwise_transfer_.
RunWrite(dst_desc, dst_buf, thread_scratch_id);
// Body fragment whose signature is not shown in this listing.
// NOTE(review): presumably the body of MoveSrcSliceWindow — confirm.
// Threads that map into the thread cluster move the source slice window of the
// per-thread transfer object by `step`; all other threads do nothing.
171 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
172 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
174 threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
// Body fragment whose signature is not shown in this listing.
// NOTE(review): presumably the body of MoveDstSliceWindow — confirm.
// Threads that map into the thread cluster move the destination slice window
// of the per-thread transfer object by `step`; all other threads do nothing.
180 if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
181 ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
183 threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
// Compile-time descriptor of the thread cluster. The member functions call its
// GetElementSize() for the participation guard, and the constructor calls its
// CalculateBottomIndex() to turn a thread id into a cluster index.
// NOTE(review): the initializer is not shown in this listing.
190 static constexpr auto thread_cluster_desc_ =
// Alias for the per-thread transfer type. Only part of its template argument
// list appears here: the three element-wise operation types, the scale
// scalars-per-vector value, the three scalar strides, and the two
// reset-coordinate flags.
193 using ThreadwiseTransfer =
196 SrcElementwiseOperation,
197 ScaleElementwiseOperation,
198 DstElementwiseOperation,
211 ScaleScalarPerVector,
213 SrcScalarStrideInVector,
214 ScaleScalarStrideInVector,
215 DstScalarStrideInVector,
216 ThreadTransferSrcResetCoordinateAfterRun,
217 ThreadTransferDstResetCoordinateAfterRun,
// The per-thread transfer object; each member function shown in this listing
// forwards its work to it.
220 ThreadwiseTransfer threadwise_transfer_;
int32_t index_t
Definition ck.hpp:299
InMemoryDataOperationEnum
Definition ck.hpp:277
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
typename remove_reference< T >::type remove_reference_t
Definition type.hpp:292
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:121
static constexpr auto thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:59
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:144
static constexpr auto scale_thread_slice_lengths
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:60
__device__ void MoveDstSliceWindow(const DstDesc &dst_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:178
__device__ void RunScaleRead(const ScaleDesc &scale_desc, const ScaleBuffer &scale_buf)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:134
__device__ constexpr ThreadGroupTensorSliceTransfer_v4r1_dequant(const SrcDesc &src_desc, const Index &src_block_slice_origin, const SrcElementwiseOperation &src_element_op, const ScaleDesc &scale_desc, const Index &scale_block_slice_origin, const ScaleElementwiseOperation &scale_element_op, const DstDesc &dst_desc, const Index &dst_block_slice_origin, const DstElementwiseOperation &dst_element_op)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:65
MultiIndex< nDim > Index
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:63
static constexpr index_t nDim
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:57
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &step)
Definition thread_group_tensor_slice_transfer_v4r1_dequant.hpp:169
__device__ void RunWrite(const DstDesc &dst_desc, DstBuffer &dst_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:548
__device__ void RunRead(const SrcDesc &src_desc, const SrcBuffer &src_buf, Number< ThreadScratchId > thread_scratch_id=Number< ThreadScratchId >{})
Definition threadwise_tensor_slice_transfer_v3r1_dequant.hpp:129