threadwise_welford.hpp Source File

threadwise_welford.hpp Source File#

Composable Kernel: threadwise_welford.hpp Source File
threadwise_welford.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
7
8namespace ck {
9
10// Assume
11// 1) XDesc is known at compile-time
12// 2) MeanVarDesc is known at compile-time
13// 3) XBuffer is static buffer
14// 4) MeanBuffer is static buffer
15// 5) VarBuffer is static buffer
16template <typename T, typename XThreadDesc_M_K, typename MeanVarThreadDesc_M>
18{
19 static constexpr auto x_thread_desc_m_k = XThreadDesc_M_K{};
20 static constexpr auto mean_var_thread_desc_m = MeanVarThreadDesc_M{};
21
22 static constexpr auto thread_x_length_m = x_thread_desc_m_k.GetLength(Number<0>{});
23 static constexpr auto thread_x_length_k = x_thread_desc_m_k.GetLength(Number<1>{});
24 static constexpr auto thread_mean_var_length_m = mean_var_thread_desc_m.GetLength(Number<0>{});
25
27 "lengths of source and mean/var buffer must match!");
28
29 __device__ constexpr ThreadwiseWelford() : cur_count_(0), max_count_(0) {}
30
31 __device__ inline void Update(T& mean, T& var, T x)
32 {
33 using ck::math::isnan;
34
35 if(isnan(x))
36 {
37 mean = x;
38 var = x;
39 }
40 else
41 {
42 T delta = x - mean;
43 mean += delta / cur_count_;
44 T delta2 = x - mean;
45 var += delta * delta2;
46 }
47 }
48
// Feeds samples from x_buf_m_k into the running statistics held in mean_buf_m
// and var_buf_m by calling Update() once per (iM, iK) element.
//
// NOTE(review): this listing is an extract with gaps. The numbering jumps
// 54 -> 57 and 59 -> 61, so the statements that introduce iM and iK and the
// statement that opens the scope at listing line 57 are not visible here.
// Restore them from the original header before compiling; nothing below should
// be read as the complete function body.
49 template <typename XBufferType, typename MeanBufferType, typename VarBufferType>
50 __device__ void
51 Run(const XBufferType& x_buf_m_k, MeanBufferType& mean_buf_m, VarBufferType& var_buf_m)
52 {
53 // FIXME - Better naming for var_buf_m
54
// NOTE(review): listing lines 55-56 are missing at this point (presumably the
// outer iteration and the guard that opens the brace below -- confirm).
57 {
// cur_count_ is advanced before Update() runs; Update() divides by it.
58 ++cur_count_;
59
// NOTE(review): listing line 60 is missing at this point (presumably the
// iteration that supplies iM -- confirm).
// Offset into the mean/var buffers, computed from iM only.
61 constexpr index_t out_offset =
62 mean_var_thread_desc_m.CalculateOffset(make_tuple(iM));
63
// Offset into the input buffer, computed from (iM, iK).
64 constexpr auto in_offset =
65 x_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
// mean/var are accessed with operator() (passed as T& to Update); the input
// sample is accessed with operator[].
66 Update(mean_buf_m(Number<out_offset>{}),
67 var_buf_m(Number<out_offset>{}),
68 x_buf_m_k[Number<in_offset>{}]);
69 });
70 }
71 });
72 };
73
76};
77
78template <typename T,
79 typename SrcMeanVarCountThreadDesc_M_K,
80 typename DstMeanVarThreadDesc_M,
81 bool GetActualVariance = false>
83{
84 static constexpr auto src_thread_desc_m_k = SrcMeanVarCountThreadDesc_M_K{};
85 static constexpr auto dst_thread_desc_m = DstMeanVarThreadDesc_M{};
86
87 static constexpr auto src_length_m = src_thread_desc_m_k.GetLength(Number<0>{});
88 static constexpr auto src_length_k = src_thread_desc_m_k.GetLength(Number<1>{});
89 static constexpr auto dst_length_m = dst_thread_desc_m.GetLength(Number<0>{});
90
91 static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
92
93 __device__ static void
94 Merge(T& mean_a, T& var_a, int32_t& count_a, T mean_b, T var_b, int32_t count_b)
95 {
96 int count = count_a + count_b;
97 T count_b_over_count = count == 0 ? type_convert<T>(0) : type_convert<T>(count_b) / count;
98 T delta = mean_b - mean_a;
99 mean_a += delta * count_b_over_count;
100 var_a += var_b + delta * delta * count_a * count_b_over_count;
101 count_a = count;
102 }
103
104 template <typename SrcMeanBufferType,
105 typename SrcVarBufferType,
106 typename SrcCountBufferType,
107 typename DstMeanBufferType,
108 typename DstVarBufferType,
109 typename DstCountBufferType>
110 __device__ static void Run(const SrcMeanBufferType& src_mean_buf,
111 const SrcVarBufferType& src_var_buf,
112 const SrcCountBufferType& src_count_buf,
113 DstMeanBufferType& dst_mean_buf,
114 DstVarBufferType& dst_var_buf,
115 DstCountBufferType& dst_count_buf)
116 {
117 static_for<0, src_length_m, 1>{}([&](auto iM) {
118 static_for<0, src_length_k, 1>{}([&](auto iK) {
119 constexpr auto src_offset = src_thread_desc_m_k.CalculateOffset(make_tuple(iM, iK));
120
121 Merge(dst_mean_buf(iM),
122 dst_var_buf(iM),
123 dst_count_buf(iM),
124 src_mean_buf[Number<src_offset>{}],
125 src_var_buf[Number<src_offset>{}],
126 src_count_buf[Number<src_offset>{}]);
127 });
128
129 if constexpr(GetActualVariance)
130 {
131 dst_var_buf(iM) = dst_var_buf[iM] / dst_count_buf[iM];
132 };
133 });
134 };
135};
136
137} // namespace ck
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
integral_constant< index_t, N > Number
Definition number.hpp:12
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
signed int int32_t
Definition stdint.h:123
__device__ void Update(T &mean, T &var, T x)
Definition threadwise_welford.hpp:31
static constexpr auto thread_mean_var_length_m
Definition threadwise_welford.hpp:24
__device__ constexpr ThreadwiseWelford()
Definition threadwise_welford.hpp:29
static constexpr auto mean_var_thread_desc_m
Definition threadwise_welford.hpp:20
static constexpr auto x_thread_desc_m_k
Definition threadwise_welford.hpp:19
__device__ void Run(const XBufferType &x_buf_m_k, MeanBufferType &mean_buf_m, VarBufferType &var_buf_m)
Definition threadwise_welford.hpp:51
static constexpr auto thread_x_length_m
Definition threadwise_welford.hpp:22
static constexpr auto thread_x_length_k
Definition threadwise_welford.hpp:23
Definition threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
static constexpr auto src_length_k
Definition threadwise_welford.hpp:88
static __device__ void Merge(T &mean_a, T &var_a, int32_t &count_a, T mean_b, T var_b, int32_t count_b)
Definition threadwise_welford.hpp:94
static constexpr auto dst_thread_desc_m
Definition threadwise_welford.hpp:85
static constexpr auto dst_length_m
Definition threadwise_welford.hpp:89
static constexpr auto src_thread_desc_m_k
Definition threadwise_welford.hpp:84
static constexpr auto src_length_m
Definition threadwise_welford.hpp:87
Definition functional2.hpp:33