reference_rmsnorm2d_fwd.hpp Source File

reference_rmsnorm2d_fwd.hpp Source File#

Composable Kernel: reference_rmsnorm2d_fwd.hpp Source File
reference_rmsnorm2d_fwd.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
9
10namespace ck_tile {
11
12// Note: for simplicity, each functor only care about single M
14{
15 template <typename OutDataType, typename AccDataType>
17 {
18 const int N = acc.mDesc.get_lengths()[1];
19 for(int n = 0; n < N; ++n)
20 {
21 o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
22 }
23 }
24
25 template <typename OutDataType, typename AccDataType>
26 auto operator()(int m, const HostTensor<AccDataType>& acc)
27 {
29 operator()(m, o, acc);
30 return o;
31 }
32};
33
34template <typename XDataType,
35 typename GammaDataType,
36 typename ComputeDataType,
37 typename YDataType,
38 typename InvRmsDataType,
39 typename UnquantYDataType,
40 typename Epilogue = reference_rmsnorm2d_default_epilogue>
42 const HostTensor<GammaDataType>& gamma_n,
45 HostTensor<UnquantYDataType>& unquant_y_m_n,
46 ComputeDataType epsilon,
47 Epilogue epilogue_functor = {},
48 const int use_model_sensitive_rmsnorm =
50{
51 auto rmsnorm2d_fwd_func = [&](auto m) {
52 const int N = x_m_n.mDesc.get_lengths()[1];
53
54 ComputeDataType mean_square = 0;
55 ComputeDataType divisor = 0;
56
57 for(int n = 0; n < N; ++n)
58 {
59 ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
60 mean_square += x * x;
61 }
62
63 mean_square = mean_square / N;
64 divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);
65
66 if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
67 invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
68
70 for(int n = 0; n < N; ++n)
71 {
72 ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
73 ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
74 if(use_model_sensitive_rmsnorm ==
75 static_cast<int>(
76 Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
77 {
78 acc(m, n) = x * divisor * gamma;
79 }
80 else if(use_model_sensitive_rmsnorm ==
81 static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
82 {
83 if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
84 {
85 const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
87 type_convert<ComputeDataType>(tmp0) * gamma);
88 const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
89 acc(m, n) = rmsn_;
90 }
91 else
92 {
93 const auto tmp = type_convert<XDataType>(x * divisor);
94 const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
95 acc(m, n) = rmsn_;
96 }
97 }
98 }
99
100 if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
101 {
102 epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
103 }
104 else
105 {
106 epilogue_functor(m, y_m_n, acc);
107 }
108 };
109
110 make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
111 std::thread::hardware_concurrency());
112}
113
114} // namespace ck_tile
ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition bfloat16.hpp:284
@ NO_SPECIFIC_MODEL
Definition rmsnorm2d_fwd_traits.hpp:42
@ T5_MODEL_LIKE
Definition rmsnorm2d_fwd_traits.hpp:46
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition bfloat16.hpp:413
void reference_rmsnorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, HostTensor< YDataType > &y_m_n, HostTensor< InvRmsDataType > &invRms_m, HostTensor< UnquantYDataType > &unquant_y_m_n, ComputeDataType epsilon, Epilogue epilogue_functor={}, const int use_model_sensitive_rmsnorm=static_cast< int >(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
Definition reference_rmsnorm2d_fwd.hpp:41
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
ck_tile::HostTensor
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
decltype(auto) get_strides() const
Definition tile/host/host_tensor.hpp:394
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800
ck_tile::reference_rmsnorm2d_default_epilogue
Definition reference_rmsnorm2d_fwd.hpp:14
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition reference_rmsnorm2d_fwd.hpp:26
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition reference_rmsnorm2d_fwd.hpp:16