gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp Source File

gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp Source File#

Composable Kernel: gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp Source File
gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
10
11namespace ck_tile {
12// Default policy for GemmPipelineAGmemBGmemCregComputeV5. Except for the block gemm method, it
13// shares the same vector size implementation, SmemSize, and global memory tile distribution as
14// the UniversalGemm Pipeline Policy.
15// Default policy class should not be templated, put template on
16// member functions instead.
18 : public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV5DefaultPolicy>
19{
20 template <typename Problem>
21 CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
22 {
23 // using AccDataType = float;
24 using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
25 using WarpTile = typename Problem::BlockGemmShape::WarpTile;
26 using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
27 typename Problem::BDataType,
28 typename Problem::CDataType, // AccDataType
29 WarpTile::at(I0),
30 WarpTile::at(I1),
31 WarpTile::at(I2),
32 Problem::TransposeC>;
33
34 using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
35 typename Problem::BDataType,
36 typename Problem::CDataType,
37 BlockWarps,
38 WarpGemm>;
39
41 }
42
43 template <typename Problem>
45 {
46 constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
47 constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
48
49 return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock,
50 16);
51 }
52
53 template <typename Problem>
55 {
56 constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
57 constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
58 constexpr index_t smem_size_c = GetSmemSizeC<Problem>();
59
60 return smem_size_a + smem_size_b >= smem_size_c ? (smem_size_a + smem_size_b)
61 : (smem_size_c);
62 }
63};
64} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
CK_TILE_HOST_DEVICE constexpr auto integer_least_multiple(X x, Y y)
Definition tile/core/numeric/math.hpp:155
int32_t index_t
Definition integer.hpp:9
Definition block_gemm_areg_breg_creg_v1_custom_policy.hpp:16
Definition block_gemm_areg_breg_creg_v1.hpp:18
Definition gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp:19
static CK_TILE_HOST_DEVICE constexpr auto GetBlockGemm()
Definition gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp:21
static CK_TILE_HOST_DEVICE constexpr index_t GetSmemSize()
Definition gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp:54
static CK_TILE_DEVICE constexpr index_t GetSmemSizeC()
Definition gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp:44
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:34
static constexpr auto I1
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:50
static CK_TILE_DEVICE constexpr index_t GetSmemSizeB()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:671
static constexpr auto I2
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:51
static CK_TILE_DEVICE constexpr index_t GetSmemSizeA()
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:661
static constexpr auto I0
Definition gemm_universal_pipeline_ag_bg_cr_policy.hpp:49