block_universal_gemm_as_aquant_bs_cr.hpp Source File

block_universal_gemm_as_aquant_bs_cr.hpp Source File

Composable Kernel: block_universal_gemm_as_aquant_bs_cr.hpp Source File
block_universal_gemm_as_aquant_bs_cr.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
12
13namespace ck_tile {
14
// Common base for the AQuant block GEMM. The struct's own declaration line and its type
// aliases are not part of this listing; the symbol index at the end of the page names it
// BlockGemmAQuantBase and lists AQDataType / ComputeDataType aliases taken from Problem.
15template <typename Problem>
17{
20
// Converts one raw scale value to fp32. The branch is chosen by AQDataType (the scale type
// of Problem), not by T:
//  - fp8_t / bf8_t: the conversion expressions sit on lines not shown in this listing; the
//    symbol index lists amd_assembly_fp8_to_fp32 / amd_assembly_bf8_to_fp32, presumably
//    used to decode the 8-bit pattern carried in `scale` (TODO confirm).
//  - float: reinterprets the bits of `scale` as fp32 via ck_tile::bit_cast, so `scale` is
//    expected to carry the fp32 bit pattern (e.g. the int produced by a cross-lane permute).
// NOTE(review): the final `static_assert(false, ...)` relies on the compiler not evaluating
// static_assert in a discarded `if constexpr` branch of a template (P2593, C++23, also
// implemented by recent Clang). Older compilers reject it; a type-dependent false condition
// would be the portable spelling.
21 template <typename T>
22 CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale)
23 {
24 float scale_reg_f = 0.f;
25 if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
26 {
27 scale_reg_f =
29 }
30 else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
31 {
32 scale_reg_f =
34 }
35 else if constexpr(std::is_same_v<AQDataType, float>)
36 {
37 scale_reg_f = ck_tile::bit_cast<float>(scale);
38 }
39 else
40 {
41 static_assert(false, "AQDataType must be float, fp8_t or bf8_t.");
42 }
43 return scale_reg_f;
44 }
45};
46
47// A is block window on shared memory
48// AQ (scale tensor) is block distributed tensor.
49// Consecutive QuantGroupSize elements of A are quantized with a separate scale.
50// B is block window on shared memory
51// C is block distributed tensor
52template <typename Problem_,
53 typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy,
54 index_t UnaryOpSize_ = 8>
56{
57 private:
// Compile-time configuration of the block GEMM, derived from the pipeline problem and the
// GEMM policy: block/warp tile sizes, per-warp iteration counts, and how A's quantization
// groups (QuantGroupSize::kK elements along K per scale) map onto warp-GEMM K iterations.
// Several aliases (Problem, BlockGemmShape, QuantGroupSize, the data types) are declared on
// lines not shown in this listing.
58 template <typename PipelineProblem_, typename GemmPolicy_>
59 struct GemmTraits_
60 {
62 using Policy = remove_cvref_t<GemmPolicy_>;
70
71 static constexpr index_t kBlockSize = Problem::kBlockSize;
72 static constexpr auto Scheduler = Problem::Scheduler;
73
74 // Threadblock GEMM tile size
75 static constexpr index_t MPerBlock = BlockGemmShape::kM;
76 static constexpr index_t NPerBlock = BlockGemmShape::kN;
77 static constexpr index_t KPerBlock = BlockGemmShape::kK;
// Number of A scales along K per block tile (floor division; compare QScalesPerBlockRow).
78 static constexpr index_t AQPerBlock = KPerBlock / QuantGroupSize::kK;
79
// The policy returns a tuple-like config: (warp GEMM type, MWarp, NWarp).
80 static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
81 using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
82
83 // number of warps along M and N for threadblock's GEMM problem size
84 static constexpr index_t MWarp = config.template at<1>();
85 static constexpr index_t NWarp = config.template at<2>();
86
87 using I0 = number<0>;
88 using I1 = number<1>;
89
90 static_assert(MWarp == BlockGemmShape::BlockWarps::at(I0{}),
91 "Error! WarpGemm's MWarp is not consistent with BlockGemmShape!");
92 static_assert(NWarp == BlockGemmShape::BlockWarps::at(I1{}),
93 "Error! WarpGemm's NWarp is not consistent with BlockGemmShape!");
94 static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(I0{}),
95 "Error! WarpGemm's M is not consistent with BlockGemmShape!");
96 static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(I1{}),
97 "Error! WarpGemm's N is not consistent with BlockGemmShape!");
98
99 static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
100 static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
101 static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
102
// Scales along K per block tile (rounded up) and per warp-GEMM K step.
// NOTE(review): AQPerBlock above uses floor division. The two only agree when KPerBlock is a
// multiple of QuantGroupSize::kK, which the asserts below do not enforce directly (e.g.
// KPerBlock = 128, WarpGemm::kK = 32, QuantGroupSize::kK = 96 passes all of them and gives
// AQPerBlock = 1 but QScalesPerBlockRow = 2). AQPicker indexes with both.
103 static constexpr index_t QScalesPerBlockRow =
104 integer_divide_ceil(KPerBlock, QuantGroupSize::kK);
105 static constexpr index_t QScalesPerWarpGemmRow =
106 integer_divide_ceil(WarpGemm::kK, QuantGroupSize::kK);
107
// Warp-GEMM K iterations that share one scale; operator() accumulates these before scaling.
108 static constexpr index_t KIterPerQScale = KIterPerWarp / QScalesPerBlockRow;
109
// NOTE(review): the message of the next assert is inverted relative to its condition. The
// check requires QuantGroupSize::kK to be a multiple of WarpGemm::kK.
110 static_assert(QuantGroupSize::kK % WarpGemm::kK == 0,
111 "Error! WarpGemm::kK should be a multiple of QuantGroupSize");
// Implied by the previous assert (it forces QuantGroupSize::kK >= WarpGemm::kK).
112 static_assert(QScalesPerWarpGemmRow == 1,
113 "Error! QuantGroupSize shouldn't be smaller than WarpGemm::kK");
114 static_assert(KIterPerWarp % QScalesPerBlockRow == 0,
115 "Error! KItersPerWarp should be a multiple of QscalesPerBlockRow");
116
117 static_assert(KPerBlock / QuantGroupSize::kK > 0,
118 "Error! Each row of blockgemm should have a separate scale");
119
120 static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
121 "Error! Warps should cover all Block tile!");
122 static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
123 "Error! Warps should cover all Block tile!");
124
125 // Currently tested combinations (A, AQ, B)
126 // 1. fp8, fp32, fp8 -> f32
127 // 2. bf8, fp32, bf8 -> f32
128 // 3. i4, (fp8/fp32) fp8 -> f32
129 // 4. i4, (fp8/fp32) bf8 -> f32
// Note: B is restricted to fp8/bf8 here (no pk_int4_t B).
130 static_assert((std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
131 std::is_same_v<ADataType, bf8_t>) &&
132 (std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>) &&
133 (std::is_same_v<AQDataType, float> ||
134 std::is_same_v<AQDataType, ck_tile::fp8_t> ||
135 std::is_same_v<AQDataType, ck_tile::bf8_t>) &&
136 (std::is_same_v<ComputeDataType, fp8_t> ||
137 std::is_same_v<ComputeDataType, bf8_t>) &&
138 std::is_same_v<CDataType, fp32_t>);
139
// A single MAC cluster: the Interwave K split in the A/B distribution encoders then covers
// the whole per-thread K range.
140 static constexpr index_t InterWaveSchedulingMacClusters = 1;
141
142 static constexpr index_t KPack = WarpGemm::kKPerThread;
143 static constexpr index_t KPerThread = KIterPerWarp * WarpGemm::kKPerThread;
144
// Quantization-layout switches forwarded from the problem; they select the scale-lookup
// path in AQPicker.
145 static constexpr bool PreshuffleQuant = Problem::Traits::PreshuffleQuant;
146 static constexpr bool TransposeC = Problem::TransposeC;
147 };
148
149 public:
150 using Traits = GemmTraits_<Problem_, Policy_>;
151
157
159
162
163 static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
164 static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
165 static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
166
167 static constexpr index_t MWarp = Traits::MWarp;
168 static constexpr index_t NWarp = Traits::NWarp;
169
170 static constexpr auto Scheduler = Traits::Scheduler;
171
172 using AWarpDstr = typename WarpGemm::AWarpDstr;
173 using BWarpDstr = typename WarpGemm::BWarpDstr;
174 using CWarpDstr = typename WarpGemm::CWarpDstr;
175
176 using AWarpTensor = typename WarpGemm::AWarpTensor;
177 using BWarpTensor = typename WarpGemm::BWarpTensor;
178 using CWarpTensor = typename WarpGemm::CWarpTensor;
179
180 static_assert(std::is_same_v<typename WarpGemm::CDataType, float>);
181
182 static constexpr auto a_warp_y_lengths =
183 to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
184 static constexpr auto b_warp_y_lengths =
185 to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
186 static constexpr auto c_warp_y_lengths =
187 to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
188
192
193 static constexpr index_t APackedSize =
195 static constexpr index_t BPackedSize =
197
198 using I0 = number<0>;
199 using I1 = number<1>;
200
// Body of MakeABlockDistributionEncode() (its signature line is not part of this listing).
// Builds A's block-level tile distribution encoding by embedding WarpGemm's A warp encoding
// into an outer (block) encoding whose template arguments are on lines not shown here.
202 {
203 constexpr index_t KPerThread = Traits::KPerThread;
204 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
205
// K elements per thread per inner loop: the per-thread K split across the MAC clusters, but
// never less than one warp GEMM's worth (WarpGemm::kKPerThread).
206 constexpr index_t KPerInnerLoop =
207 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
208
// Warp-GEMM K iterations per inner loop; presumably consumed by the Interwave arm of KIterSeq
// (the conditional_t arms are on lines not shown in this listing).
209 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
210
211 using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
214
215 constexpr auto a_block_outer_dstr_encoding =
222 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
223 a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
224
225 return a_block_dstr_encode;
226 }
227
// Body of MakeBBlockDistributionEncode() (its signature line is not part of this listing).
// Builds B's block-level tile distribution encoding by embedding WarpGemm's B warp encoding
// into an outer (block) encoding whose template arguments are on lines not shown here.
// NOTE(review): the K-iteration setup below duplicates the one in the A encoder; a shared
// helper would keep the two in sync.
229 {
230 constexpr index_t KPerThread = Traits::KPerThread;
231 constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters;
// K elements per thread per inner loop, at least one warp GEMM's worth.
232 constexpr index_t KPerInnerLoop =
233 ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread);
234 constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread;
235
236 using KIterSeq = std::conditional_t<Scheduler == GemmPipelineScheduler::Interwave,
239
240 constexpr auto b_block_outer_dstr_encoding =
247 constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
248 b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
249
250 return b_block_dstr_encode;
251 }
252
253 private:
254 template <GemmPipelineScheduler Scheduler, typename GemmTraits>
255 struct BlockGemmImpl
256 {
257 };
258
259 template <typename GemmTraits>
260 struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
261 {
262 private:
263 CK_TILE_DEVICE static float exchange_quant_value_across_lanes(float scale_reg,
264 index_t pull_from_lane)
265 {
266 // cross lane ops
267 uint32_t scale_reg_dword;
268
269 if constexpr(std::is_same_v<AQDataType, float>)
270 {
271 scale_reg_dword = ck_tile::bit_cast<uint32_t>(scale_reg);
272 }
273 else
274 {
275 scale_reg_dword = static_cast<uint32_t>(scale_reg);
276 }
277
278 int gathered_scale_reg = __builtin_amdgcn_ds_bpermute(
279 pull_from_lane << 2, __builtin_bit_cast(int, scale_reg_dword));
280
281 return Base::cvt_scale_to_fp32(gathered_scale_reg);
282 }
283
// For one (mIter, kQScale) pair, selects the A scale that applies to each accumulator
// element `c_row` of a C warp tile and returns it as fp32 (see pick()).
// NOTE(review): the PreShuffleQuant and TransposeC template parameters are not read in the
// body; behavior is driven by Traits::PreshuffleQuant / Traits::TransposeC instead.
284 template <typename AQBlockTensor,
285 bool PreShuffleQuant,
286 bool TransposeC,
287 int32_t mIter,
288 int32_t kQScale>
289 struct AQPicker
290 {
// Binds the AQ block tensor. With TransposeC, all accumulator elements a lane holds share
// one scale (see pick()), so it is looked up once here and cached in scale_reg_f. For
// preshuffled AQ it is also fetched from the owning lane. Line 291 of the original file
// (presumably the constructor's CK_TILE_DEVICE qualifier) is not shown in this listing.
292 AQPicker(AQBlockTensor& aq_block_tensor_) : aq_block_tensor(aq_block_tensor_)
293 {
294 if constexpr(Traits::TransposeC) // transposed C
295 {
296 index_t reg_offset =
297 Traits::PreshuffleQuant ? mIter : mIter * Traits::AQPerBlock + kQScale;
298 auto scale_reg = aq_block_tensor.get_thread_buffer()[reg_offset];
299 if constexpr(Traits::PreshuffleQuant)
300 {
// NOTE(review): this lane index uses Traits::AQPerBlock as the row stride, whereas the
// non-transposed preshuffle path in pick() uses Traits::QScalesPerBlockRow. The mask with
// (kN - 1) also assumes kN is a power of two.
301 auto pull_from_lane =
302 (__lane_id() & (Traits::WarpGemm::kN - 1)) * Traits::AQPerBlock +
303 kQScale;
304
305 scale_reg_f = exchange_quant_value_across_lanes(scale_reg, pull_from_lane);
306 }
307 else
308 {
309 scale_reg_f = Base::cvt_scale_to_fp32(scale_reg);
310 }
311 }
312 }
// Returns the fp32 scale for accumulator element `c_row` of this lane's C warp tile.
313 template <uint32_t c_row = 0>
314 CK_TILE_DEVICE float pick()
315 {
316 if constexpr(Traits::TransposeC)
317 {
318 // pre-computed scale_reg_f is shared by entire column when TransposeC is true
319 return scale_reg_f;
320 }
321 else
322 {
323 if constexpr(Traits::PreshuffleQuant)
324 {
325 // A view is created on top of the preshuffled AQ, where each row of
326 // the view is composed of a row from a warp tile within an AQ block
327 // tile. Multiple warp tile rows that belong to the same block tile
328 // are laid out as consecutive rows.
329 //
330 // When we need to multiply a C warp tile with an AQ warp tile,
331 // thread 0 in the warp will load AQ_warp_tile[0], thread 1 will
332 // load AQ_warp_tile[1], and so on, up to thread 63, which will load
333 // AQ_warp_tile[63]. The VGPR file in the warp acts similarly to LDS
334 // in this context, but we use cross-lane operations to access the
335 // data. (Cross-lane operations are faster than using LDS.)
336 //
337 // Note that when the size of the AQ warp tile is smaller than the
338 // warp size, you need to pad the rows in the view to ensure that
339 // each thread can read one element.
340
341 // For a warp tile of [16x16x32], take thread 0 as an
342 // example. Its VGPR[0] stores the value from C_tile[0,0],
343 // VGPR[1] stores C_tile[1,0], VGPR[2] stores C_tile[2,0],
344 // and VGPR[3] stores C_tile[3,0]. This means VGPR[0] should
345 // be multiplied by AQ_tile[0, 0], VGPR[1] by AQ_tile[1, 0],
346 // VGPR[2] by AQ_tile[2, 0], and VGPR[3] by AQ_tile[3, 0].
347
348 // Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
349 // 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
350
// Row mapping used below: a lane's C values come in runs of 4 consecutive rows. For
// kM == 32 the runs are 8 rows apart, so c_row = 4*g + r maps to row 8*g + r.
// NOTE(review): the comments above assume a 64-lane wavefront ("up to thread 63").
351 constexpr uint32_t kTileRowsOfCPerThread = 4;
352 decltype(threadIdx.x) pull_from_lane = 0;
353 if constexpr(WarpGemm::kM == 16)
354 {
355 pull_from_lane =
356 (__lane_id() / Traits::WarpGemm::kN * kTileRowsOfCPerThread +
357 c_row) *
358 Traits::QScalesPerBlockRow +
359 kQScale;
360 }
361 else if constexpr(WarpGemm::kM == 32)
362 {
363 pull_from_lane =
364 (__lane_id() / Traits::WarpGemm::kN * kTileRowsOfCPerThread +
365 ((c_row >> 2) << 3) + (c_row & 0b11)) *
366 Traits::QScalesPerBlockRow +
367 kQScale;
368 }
369 else
370 {
371 static_assert(false, "WarpGemm::kM is not 16 nor 32.");
372 }
// In the preshuffled layout each lane's thread buffer is indexed by mIter only.
373 auto& scale_reg = aq_block_tensor.get_thread_buffer()[mIter];
374
375 return exchange_quant_value_across_lanes(scale_reg, pull_from_lane);
376 }
377 else
378 {
379 // Need to multiply aquant with accumulated C
380 //
381 // The accumulated C tile has the standard distribution. For example
382 // lane 0 holds elements [0,0], [1,0], [2,0], [3,0], [8,0], [9,0],
383 // [10,0], [11,0], [16,0], [17,0], [18,0], [19,0], [24,0], [25,0],
384 // [26,0], [27,0].
385 //
386 // These elements are in different rows, need to get the scale value
387 // for the corresponding row.
388 // Based on aquant's tile distribution, it can be inferred which
389 // lane holds the relevant scale. For example, the scales
390 // corresponding to the 16 elements held by lane 0 are held by lanes
391 // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27
392 // respectively.
393 //
394 // These scales can be obtained using __builtin_amdgcn_ds_bpermute.
395
// With one scale per lane, one register covers get_warp_size() / kM warp-tile row blocks.
// reg_block_offset selects the register and lane_base_offset the first lane for this mIter
// (presumably mirroring AQ's tile distribution — confirm).
396 // MIters per warp
397 constexpr index_t mIters_per_warp = get_warp_size() / WarpGemm::kM;
398
399 // Reg block offset based on mIter
400 constexpr index_t reg_block_offset =
401 ((mIter / mIters_per_warp) * Traits::AQPerBlock);
402
403 constexpr index_t lane_base_offset =
404 (mIter % mIters_per_warp) * WarpGemm::kM;
405
406 // Scale tensor offset along K
407 constexpr index_t src_reg_offset = reg_block_offset + kQScale;
408 // Directly index into thread buffer corresponding to
409 // desired row coefficient
410 auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
411
412 constexpr uint32_t kTileRows = 4;
413 constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
414 constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
// Equivalent to (c_row / kTileRows) * kTiledCMsPerWarp + c_row % kTileRows.
415 // Multiply by 4 because output is stored in tiles of 4
416 // x CNLane
417 constexpr uint32_t row_base =
418 ((reg_offset_for_row_data / kTiledCMsPerWarp) * kTiledCMsPerWarp) +
419 ((reg_offset_for_row_data % kTiledCMsPerWarp) / WarpGemm::kCMLane);
420
421 // Lane index to source scale from
422 uint32_t src_lane_idx =
423 lane_base_offset + row_base + (__lane_id() / WarpGemm::kN * kTileRows);
424
425 return exchange_quant_value_across_lanes(scale_reg, src_lane_idx);
426 }
427 }
428 }
429
// Non-owning reference: the picker must not outlive the tensor passed to its constructor.
430 AQBlockTensor& aq_block_tensor;
// Cached fp32 scale; only set (in the constructor) when Traits::TransposeC is true.
431 float scale_reg_f = 0.0f;
432 };
433
434 public:
435 static constexpr auto ALdsTileDistr =
437 static constexpr auto BLdsTileDistr =
439
440 using ALdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(ALdsTileDistr));
441 using BLdsTile = decltype(make_static_distributed_tensor<ComputeDataType>(BLdsTileDistr));
442
443 ALdsTile a_warp_tile_;
444 BLdsTile b_warp_tile_;
445
446 template <typename ASmemBlockWindow, typename BSmemBlockWindow>
447 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
448 const BSmemBlockWindow& b_block_window)
449 {
450 if constexpr(std::is_same_v<ADataType, pk_int4_t>)
451 {
452 static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
453 std::is_same_v<ComputeDataType, bf8_t>);
454 Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window);
455 }
456 else
457 {
458 load_tile(a_warp_tile_, a_block_window);
459 }
460 if constexpr(std::is_same_v<BDataType, pk_int4_t>)
461 {
462 static_assert(std::is_same_v<ComputeDataType, fp8_t> ||
463 std::is_same_v<ComputeDataType, bf8_t>);
464 Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window);
465 }
466 else
467 {
468 load_tile(b_warp_tile_, b_block_window);
469 }
470 }
471
472 // C += A * B
// Intrawave block GEMM with per-row A scales:
//   for every (mIter, nIter) warp tile and every quantization group kQScale along K:
//     c_warp = sum of that group's KIterPerQScale warp GEMMs (the first overwrites,
//              the rest accumulate);
//     c_block[c_row] += c_warp[c_row] * scale, where AQPicker picks the scale for each
//              accumulator element c_row from aq_block_tensor.
// Operands come from a_warp_tile_ / b_warp_tile_ (filled by LocalPrefetch).
// a_block_window and b_block_window are not used by this implementation.
// Only A-side scales are applied here.
473 template <typename CBlockTensor,
474 typename AQBlockTensor,
475 typename ASmemBlockWindow,
476 typename BSmemBlockWindow>
477 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
478 AQBlockTensor& aq_block_tensor,
479 [[maybe_unused]] ASmemBlockWindow& a_block_window,
480 [[maybe_unused]] BSmemBlockWindow& b_block_window)
481 {
482 static_assert(std::is_same_v<CDataType, typename CBlockTensor::DataType>,
483 "The CDataType as defined in traits should be the same as corresponding "
484 "C block tensor data type!");
485 constexpr auto warp_size = get_warp_size();
486
487 // hot loop:
488 static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
489 static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
490 CWarpTensor c_warp_tensor;
491
492 static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
493 static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
494 constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
495
496 AWarpTensor a_warp_tensor;
497 a_warp_tensor.get_thread_buffer() =
498 a_warp_tile_.get_y_sliced_thread_data(
499 merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
500 merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
501
502 BWarpTensor b_warp_tensor;
503 b_warp_tensor.get_thread_buffer() =
504 b_warp_tile_.get_y_sliced_thread_data(
505 merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
506 merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
507
// Start a fresh partial sum for each quantization group so that group's scale applies
// only to its own K range.
508 if constexpr(kIterInQScale == 0)
509 {
510 c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor);
511 }
512 else
513 {
514 WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
515 }
516 });
517
// Offset of this (mIter, nIter) warp tile inside c_block_tensor's thread buffer (part of
// the index expression is on a line not shown in this listing).
518 constexpr auto tbuf_offset =
519 number<typename CBlockTensor::ThreadTensorDesc{}.calculate_offset(
520 merge_sequences(sequence<mIter, nIter>{},
522 CBlockTensor::PackedSize>{};
523
// NOTE(review): for TransposeC the picker's constructor does the scale lookup (and lane
// exchange), which depends only on (mIter, kQScale). Constructing it inside the nIter
// loop repeats that work for every nIter.
524 AQPicker<AQBlockTensor,
525 Traits::PreshuffleQuant,
526 Traits::TransposeC,
527 mIter,
528 kQScale>
529 aq_picker(aq_block_tensor);
530
// kM * kN / warp_size = number of C warp-tile elements held by each lane.
531 static_for<0, WarpGemm::kM * WarpGemm::kN / warp_size, 1>{}(
532 [&](auto c_row) {
533 float scale_reg_f = aq_picker.template pick<c_row>();
534 c_block_tensor.get_thread_buffer()[tbuf_offset + c_row] +=
535 (c_warp_tensor.get_thread_buffer()[c_row] * scale_reg_f);
536 });
537 });
538 });
539 });
540 }
541 };
542
543 public:
// Creates the block-level C tile: a static distributed tensor of CDataType.
// Its distribution embeds WarpGemm's C warp encoding into a block-level outer encoding; the
// outer encoding's template arguments are on lines not shown in this listing.
// NOTE(review): whether the returned tensor is zero-initialized is not visible here.
// operator() only adds into it (+=), so callers presumably clear it first — confirm.
544 CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
545 {
546 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
553
554 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
555 c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
556 constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
557 auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
558
559 return c_block_tensor;
560 }
561
562 template <typename ASmemBlockWindow, typename BSmemBlockWindow>
563 CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
564 const BSmemBlockWindow& b_block_window)
565 {
566 block_gemm_impl_.LocalPrefetch(a_block_window, b_block_window);
567 }
568
569 // C += A * B
570 template <typename CBlockTensor,
571 typename AQBlockTensor,
572 typename ASmemBlockWindow,
573 typename BSmemBlockWindow>
574 CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
575 AQBlockTensor& aq_block_tensor,
576 const ASmemBlockWindow& a_block_window,
577 const BSmemBlockWindow& b_block_window)
578 {
579 block_gemm_impl_(c_block_tensor, aq_block_tensor, a_block_window, b_block_window);
580 }
581
582 private:
583 BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
584};
585
586} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
CK_TILE_DEVICE float amd_assembly_fp8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:258
CK_TILE_DEVICE float amd_assembly_bf8_to_fp32(uint32_t src)
Definition tile/ops/elementwise/unary_element_wise_operation.hpp:265
Definition tile/core/algorithm/cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_HOST_DEVICE constexpr auto merge_sequences(Seqs...)
Definition tile/core/container/sequence.hpp:826
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
int32_t int32_t
Definition integer.hpp:10
CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple< number< Is >... >)
Definition tile/core/container/sequence.hpp:1055
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition tile/core/container/sequence.hpp:1026
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
GemmPipelineScheduler
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:14
@ Intrawave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:16
@ Interwave
Definition gemm_pipeline_ag_bg_cr_scheduler.hpp:17
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
int32_t index_t
Definition ck.hpp:299
__device__ constexpr index_t get_warp_size()
Definition get_id.hpp:10
unsigned int uint32_t
Definition stdint.h:126
Definition block_universal_gemm_as_aquant_bs_cr.hpp:56
static constexpr auto b_warp_y_index_zeros
Definition block_universal_gemm_as_aquant_bs_cr.hpp:190
static constexpr auto c_warp_y_index_zeros
Definition block_universal_gemm_as_aquant_bs_cr.hpp:191
typename WarpGemm::BWarpDstr BWarpDstr
Definition block_universal_gemm_as_aquant_bs_cr.hpp:173
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window)
Definition block_universal_gemm_as_aquant_bs_cr.hpp:563
static constexpr auto a_warp_y_index_zeros
Definition block_universal_gemm_as_aquant_bs_cr.hpp:189
typename WarpGemm::BWarpTensor BWarpTensor
Definition block_universal_gemm_as_aquant_bs_cr.hpp:177
static constexpr auto a_warp_y_lengths
Definition block_universal_gemm_as_aquant_bs_cr.hpp:182
remove_cvref_t< typename Traits::BDataType > BDataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:154
static constexpr index_t APackedSize
Definition block_universal_gemm_as_aquant_bs_cr.hpp:193
static constexpr index_t MWarp
Definition block_universal_gemm_as_aquant_bs_cr.hpp:167
remove_cvref_t< InterleavedPKTypeLoader< ComputeDataType, UnaryOpSize_ > > Loader
Definition block_universal_gemm_as_aquant_bs_cr.hpp:160
remove_cvref_t< typename Traits::AQDataType > AQDataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:153
BlockGemmAQuantBase< Problem_ > Base
Definition block_universal_gemm_as_aquant_bs_cr.hpp:158
static constexpr auto Scheduler
Definition block_universal_gemm_as_aquant_bs_cr.hpp:170
typename WarpGemm::AWarpTensor AWarpTensor
Definition block_universal_gemm_as_aquant_bs_cr.hpp:176
typename WarpGemm::CWarpTensor CWarpTensor
Definition block_universal_gemm_as_aquant_bs_cr.hpp:178
remove_cvref_t< typename Traits::ADataType > ADataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:152
typename WarpGemm::CWarpDstr CWarpDstr
Definition block_universal_gemm_as_aquant_bs_cr.hpp:174
static CK_TILE_DEVICE constexpr auto MakeBBlockDistributionEncode()
Definition block_universal_gemm_as_aquant_bs_cr.hpp:228
static constexpr index_t BPackedSize
Definition block_universal_gemm_as_aquant_bs_cr.hpp:195
remove_cvref_t< typename Traits::CDataType > CDataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:156
static constexpr auto b_warp_y_lengths
Definition block_universal_gemm_as_aquant_bs_cr.hpp:184
static constexpr index_t NIterPerWarp
Definition block_universal_gemm_as_aquant_bs_cr.hpp:165
static constexpr index_t KIterPerWarp
Definition block_universal_gemm_as_aquant_bs_cr.hpp:163
static constexpr index_t NWarp
Definition block_universal_gemm_as_aquant_bs_cr.hpp:168
remove_cvref_t< typename Traits::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:155
static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
Definition block_universal_gemm_as_aquant_bs_cr.hpp:544
number< 1 > I1
Definition block_universal_gemm_as_aquant_bs_cr.hpp:199
remove_cvref_t< typename Traits::WarpGemm > WarpGemm
Definition block_universal_gemm_as_aquant_bs_cr.hpp:161
static constexpr index_t MIterPerWarp
Definition block_universal_gemm_as_aquant_bs_cr.hpp:164
static constexpr auto c_warp_y_lengths
Definition block_universal_gemm_as_aquant_bs_cr.hpp:186
CK_TILE_DEVICE void operator()(CBlockTensor &c_block_tensor, AQBlockTensor &aq_block_tensor, const ASmemBlockWindow &a_block_window, const BSmemBlockWindow &b_block_window)
Definition block_universal_gemm_as_aquant_bs_cr.hpp:574
typename WarpGemm::AWarpDstr AWarpDstr
Definition block_universal_gemm_as_aquant_bs_cr.hpp:172
static CK_TILE_DEVICE constexpr auto MakeABlockDistributionEncode()
Definition block_universal_gemm_as_aquant_bs_cr.hpp:201
GemmTraits_< Problem_, Policy_ > Traits
Definition block_universal_gemm_as_aquant_bs_cr.hpp:150
number< 0 > I0
Definition block_universal_gemm_as_aquant_bs_cr.hpp:198
Definition block_universal_gemm_as_aquant_bs_cr.hpp:17
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:19
static CK_TILE_DEVICE float cvt_scale_to_fp32(T scale)
Definition block_universal_gemm_as_aquant_bs_cr.hpp:22
remove_cvref_t< typename Problem::AQDataType > AQDataType
Definition block_universal_gemm_as_aquant_bs_cr.hpp:18
Definition tile/core/numeric/numeric.hpp:81
Definition tile/core/container/sequence.hpp:49
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192