20 typename ComputeDataType,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
43 typename ComputeDataType,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
160 template <
bool HasMainLoop,
164 typename ABlockTransfer,
165 typename AGridBuffer,
166 typename ABlockBuffer,
167 typename ABlockTransferStep,
170 typename BBlockTransfer,
171 typename BGridBuffer,
172 typename BBlockBuffer,
173 typename BBlockTransferStep,
174 typename CThreadBuffer,
176 typename BScaleGridBuffer,
177 typename BScaleGridDesc,
178 typename BScaleThreadDesc,
179 typename BScaleThreadTransfer,
180 typename BScaleThreadTransferStep>
183 const AGridDesc& a_grid_desc,
184 const ABlockDesc& a_block_desc,
185 ABlockTransfer& a_blockwise_copy,
186 const AGridBuffer& a_grid_buf,
187 ABlockBuffer& a_block_buf,
188 const ABlockTransferStep& a_block_copy_step,
190 const BGridDesc& b_grid_desc,
191 const BBlockDesc& b_block_desc,
192 BBlockTransfer& b_blockwise_copy,
193 const BGridBuffer& b_grid_buf,
194 BBlockBuffer& b_block_buf,
195 const BBlockTransferStep& b_block_copy_step,
197 CThreadBuffer& c_thread_buf,
199 const BScaleGridDesc& b_scale_grid_desc,
200 const BScaleThreadDesc& b_scale_thread_desc,
201 BScaleThreadTransfer& b_scale_thread_copy,
202 const BScaleGridBuffer& b_scale_grid_buf,
203 const BScaleThreadTransferStep& b_scale_thread_copy_step,
206 index_t num_loop_per_scale)
const
209 ignore = num_loop_per_scale;
216 b_scale_thread_desc.GetElementSpaceSize());
219 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
220 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
222 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
223 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
226 b_scale_thread_copy.Run(b_scale_grid_desc,
232 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
233 b_scale_thread_copy_step.At(
Number<0>{}));
235 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
236 b_scale_thread_copy_step.At(
Number<1>{}));
239 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
240 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
243 c_thread_buf.Clear();
248 if constexpr(HasMainLoop)
254 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
255 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
257 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
258 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
282 c_thread_buf_per_scale.Clear();
288 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
291 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
296 using mfma_input_type =
301 a_thread_vec.template AsType<mfma_input_type>(),
302 b_thread_vec.template AsType<mfma_input_type>(),
303 c_thread_buf_per_scale.GetVectorTypeReference(
I0));
316 b_scale_thread_copy.Run(b_scale_grid_desc,
322 b_scale_thread_copy.MoveSrcSliceWindow(
323 b_scale_grid_desc, b_scale_thread_copy_step.At(
Number<0>{}));
326 b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
327 b_scale_thread_copy_step.At(
Number<1>{}));
330 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
331 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
335 }
while(i < (num_loop - 1));
363 c_thread_buf_per_scale.Clear();
369 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
372 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
377 using mfma_input_type =
381 a_thread_vec.template AsType<mfma_input_type>(),
382 b_thread_vec.template AsType<mfma_input_type>(),
383 c_thread_buf_per_scale.GetVectorTypeReference(
I0));
398 using Base::a_thread_copy_;
399 using Base::a_thread_desc_;
400 using Base::b_thread_copy_;
401 using Base::b_thread_desc_;
402 using Base::c_thread_desc_;
__host__ __device__ constexpr auto make_static_buffer(Number< N >)
Definition static_buffer.hpp:186
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
integral_constant< index_t, N > Number
Definition number.hpp:12
TailNumber
Definition blkgemmpipe_scheduler.hpp:31
@ Full
Definition blkgemmpipe_scheduler.hpp:49
constexpr detail::ignore_t ignore
Definition utility/ignore.hpp:20
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
BlockGemmPipelineScheduler
Definition blkgemmpipe_scheduler.hpp:25
@ Intrawave
Definition blkgemmpipe_scheduler.hpp:26
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition blockwise_gemm_pipeline_xdlops_base.hpp:222
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:280
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:239
static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_base.hpp:57
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
__host__ static __device__ constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:266
__host__ static __device__ constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:294
static constexpr index_t AMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:60
__host__ static __device__ constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:253
__host__ __device__ constexpr auto & GetCThreadBuffer()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:111
static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:160
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:189
static constexpr index_t KRepeat
Definition blockwise_gemm_pipeline_xdlops_base.hpp:64
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
static constexpr index_t BMmaKStride
Definition blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ static __device__ constexpr auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:341
__host__ static __device__ constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition blockwise_gemm_pipeline_xdlops_base.hpp:307
__host__ static __device__ constexpr auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition blockwise_gemm_pipeline_xdlops_base.hpp:324
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Base BlockwiseGemmXdlops_pipeline_base< BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack > Base
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:102
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::GlobalBufferNum static constexpr index_t GlobalBufferNum
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:147
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::PrefillStages static constexpr index_t PrefillStages
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:146
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, const BScaleGridDesc &b_scale_grid_desc, const BScaleThreadDesc &b_scale_thread_desc, BScaleThreadTransfer &b_scale_thread_copy, const BScaleGridBuffer &b_scale_grid_buf, const BScaleThreadTransferStep &b_scale_thread_copy_step, index_t num_loop, index_t num_loop_per_scale) const
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:181
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop static __host__ constexpr bool BlockHasHotloop(index_t num_loop)
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:149
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::c_thread_desc_ static constexpr auto c_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:378
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::xdlops_gemm static constexpr auto xdlops_gemm
Definition blockwise_gemm_pipeline_xdlops_base.hpp:54
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::b_block_desc_n0_n1_n2_k static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:360
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::b_thread_copy_ BThreadCopy b_thread_copy_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:402
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::I0 static constexpr auto I0
Definition blockwise_gemm_pipeline_xdlops_base.hpp:36
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::a_thread_desc_ static constexpr auto a_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:366
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::a_block_desc_m0_m1_m2_k static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition blockwise_gemm_pipeline_xdlops_base.hpp:359
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::b_thread_desc_ static constexpr auto b_thread_desc_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:372
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::PrefetchStages static constexpr index_t PrefetchStages
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:145
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::a_thread_copy_ AThreadCopy a_thread_copy_
Definition blockwise_gemm_pipeline_xdlops_base.hpp:401
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::ComputeDataTypeBuf typename Base::ComputeDataTypeBuf ComputeDataTypeBuf
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:143
ck::BlockwiseGemmXdlops_pipeline_v1_b_scale< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum static __host__ constexpr TailNumber BlockLoopTailNum(index_t num_loop)
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:154
Definition blockwise_gemm_pipeline_xdlops_v1_b_scale.hpp:37
Definition functional2.hpp:33
Definition dtype_vector.hpp:10