warp_gemm_attribute_wmma_impl_16bit_traits.hpp Source File

warp_gemm_attribute_wmma_impl_16bit_traits.hpp Source File#

Composable Kernel: warp_gemm_attribute_wmma_impl_16bit_traits.hpp Source File
warp_gemm_attribute_wmma_impl_16bit_traits.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7namespace ck_tile {
8// fp16 specialization - GFX11
9template <>
10struct WmmaTraits<gfx11_t, fp16_t, fp16_t, float, 16, 16, 16>
11 : WmmaTraitsBase<gfx11_t, fp16_t, fp16_t, float>
12{
13 template <bool clamp = false>
14 CK_TILE_DEVICE static CVecType
15 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
16 {
17#ifdef __gfx11__
18 return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_vec, b_vec, c_vec);
19#else
20 ck_tile::ignore = a_vec;
21 ck_tile::ignore = b_vec;
22 ck_tile::ignore = c_vec;
23 return CVecType{0.f};
24#endif
25 }
26};
27
28// bf16 specialization - GFX11
29template <>
30struct WmmaTraits<gfx11_t, bf16_t, bf16_t, float, 16, 16, 16>
31 : WmmaTraitsBase<gfx11_t, bf16_t, bf16_t, float>
32{
33 template <bool clamp = false>
34 CK_TILE_DEVICE static CVecType
35 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
36 {
37#ifdef __gfx11__
38 return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_vec, b_vec, c_vec);
39#else
40 ck_tile::ignore = a_vec;
41 ck_tile::ignore = b_vec;
42 ck_tile::ignore = c_vec;
43 return CVecType{0.f};
44#endif
45 }
46};
47
48// fp16 specialization - GFX12
49template <>
50struct WmmaTraits<gfx12_t, fp16_t, fp16_t, float, 16, 16, 16>
51 : WmmaTraitsBase<gfx12_t, fp16_t, fp16_t, float>
52{
53 template <bool clamp = false>
54 CK_TILE_DEVICE static CVecType
55 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
56 {
57#ifdef __gfx12__
58 return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_vec, b_vec, c_vec);
59#else
60 ck_tile::ignore = a_vec;
61 ck_tile::ignore = b_vec;
62 ck_tile::ignore = c_vec;
63 return CVecType{0.f};
64#endif
65 }
66};
67
68// bf16 specialization - GFX12
69template <>
70struct WmmaTraits<gfx12_t, bf16_t, bf16_t, float, 16, 16, 16>
71 : WmmaTraitsBase<gfx12_t, bf16_t, bf16_t, float>
72{
73 template <bool clamp = false>
74 CK_TILE_DEVICE static CVecType
75 wmma_intrinsic(const AVecType& a_vec, const BVecType& b_vec, const CVecType& c_vec)
76 {
77#ifdef __gfx12__
78 return __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_vec, b_vec, c_vec);
79#else
80 ck_tile::ignore = a_vec;
81 ck_tile::ignore = b_vec;
82 ck_tile::ignore = c_vec;
83 return CVecType{0.f};
84#endif
85 }
86};
87} // namespace ck_tile
#define CK_TILE_DEVICE
Definition config.hpp:41
Definition tile/core/algorithm/cluster_descriptor.hpp:13
bfloat16_t bf16_t
Definition bfloat16.hpp:113
_Float16 fp16_t
Definition half.hpp:110
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_16bit_traits.hpp:35
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_16bit_traits.hpp:15
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_16bit_traits.hpp:75
static CK_TILE_DEVICE CVecType wmma_intrinsic(const AVecType &a_vec, const BVecType &b_vec, const CVecType &c_vec)
Definition warp_gemm_attribute_wmma_impl_16bit_traits.hpp:55
Definition warp_gemm_attribute_wmma_impl_base_traits.hpp:7
Definition warp_gemm_attribute_wmma_impl.hpp:19
Definition arch.hpp:363
Definition arch.hpp:366