statically_indexed_array_multi_index.hpp Source File

statically_indexed_array_multi_index.hpp Source File

Composable Kernel: statically_indexed_array_multi_index.hpp Source File
statically_indexed_array_multi_index.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
5#define CK_STATICALLY_INDEXED_ARRAY_MULTI_INDEX_HPP
6
7#include "common_header.hpp"
9
10namespace ck {
11
12template <index_t N>
14
15template <typename... Xs>
16__host__ __device__ constexpr auto make_multi_index(Xs&&... xs)
17{
19}
20
21template <index_t NSize>
22__host__ __device__ constexpr auto make_zero_multi_index()
23{
24 return unpack([](auto... xs) { return make_multi_index(xs...); },
26}
27
28template <typename T>
29__host__ __device__ constexpr auto to_multi_index(const T& x)
30{
31 return unpack([](auto... ys) { return make_multi_index(ys...); }, x);
32}
33
34// Here should use MultiIndex<NSize>, instead of Tuple<Ys...>, although the former
35// is the alias of the latter. This is because compiler cannot infer the NSize if
36// using MultiIndex<NSize>
37// TODO: how to fix this?
38template <typename... Ys,
39 typename X,
41__host__ __device__ constexpr auto operator+=(Tuple<Ys...>& y, const X& x)
42{
43 static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
44 constexpr index_t NSize = sizeof...(Ys);
45 static_for<0, NSize, 1>{}([&](auto i) { y(i) += x[i]; });
46 return y;
47}
48
49template <typename... Ys,
50 typename X,
51 enable_if_t<!ck::is_integral<X>::value && !ck::is_floating_point<X>::value, bool> = false>
52__host__ __device__ constexpr auto operator-=(Tuple<Ys...>& y, const X& x)
53{
54 static_assert(X::Size() == sizeof...(Ys), "wrong! size not the same");
55 constexpr index_t NSize = sizeof...(Ys);
56 static_for<0, NSize, 1>{}([&](auto i) { y(i) -= x[i]; });
57 return y;
58}
59
60template <typename... Xs,
61 typename Y,
62 enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
63__host__ __device__ constexpr auto operator+(const Tuple<Xs...>& x, const Y& y)
64{
65 static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
66 constexpr index_t NSize = sizeof...(Xs);
67
68 Tuple<Xs...> r;
69 static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] + y[i]; });
70 return r;
71}
72
73template <typename... Xs,
74 typename Y,
75 enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
76__host__ __device__ constexpr auto operator-(const Tuple<Xs...>& x, const Y& y)
77{
78 static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
79 constexpr index_t NSize = sizeof...(Xs);
80
81 Tuple<Xs...> r;
82 static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] - y[i]; });
83 return r;
84}
85
86template <typename... Xs,
87 typename Y,
88 enable_if_t<!ck::is_integral<Y>::value && !ck::is_floating_point<Y>::value, bool> = false>
89__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, const Y& y)
90{
91 static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
92 constexpr index_t NSize = sizeof...(Xs);
93
94 Tuple<Xs...> r;
95 static_for<0, NSize, 1>{}([&](auto i) { r(i) = x[i] * y[i]; });
96 return r;
97}
98
99// MultiIndex = scalar * MultiIndex
100template <typename... Xs,
101 typename Y,
102 enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
103__host__ __device__ constexpr auto operator*(Y a, const Tuple<Xs...>& x)
104{
105 constexpr index_t NSize = sizeof...(Xs);
106
107 Tuple<Xs...> r;
108 static_for<0, NSize, 1>{}([&](auto i) { r(i) = a * x[i]; });
109 return r;
110}
111
112// MultiIndex = MultiIndex * scalar
113template <typename... Xs,
114 typename Y,
115 enable_if_t<ck::is_integral<Y>::value || ck::is_floating_point<Y>::value, bool> = false>
116__host__ __device__ constexpr auto operator*(const Tuple<Xs...>& x, Y a)
117{
118 return a * x;
119}
120
121namespace mathext {
122
123template <typename... Xs>
124__host__ __device__ constexpr auto exp(const Tuple<Xs...>& x)
125{
126 constexpr index_t NSize = sizeof...(Xs);
127
128 Tuple<Xs...> r;
129 static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::exp(x[i]); });
130 return r;
131}
132
133template <typename... Xs, typename Y>
134__host__ __device__ constexpr auto max(const Tuple<Xs...>& x, const Y& y)
135{
136 static_assert(Y::Size() == sizeof...(Xs), "wrong! size not the same");
137 constexpr index_t NSize = sizeof...(Xs);
138
139 Tuple<Xs...> r;
140 static_for<0, NSize, 1>{}([&](auto i) { r(i) = math::max(x[i], y[i]); });
141 return r;
142}
143
144} // namespace mathext
145
146template <typename... Xs>
147__host__ __device__ void print_multi_index(const Tuple<Xs...>& x)
148{
149 printf("{");
150 printf("MultiIndex, ");
151 printf("size %d,", index_t{sizeof...(Xs)});
152 static_for<0, sizeof...(Xs), 1>{}(
153 [&](auto i) { printf("%d ", static_cast<index_t>(x.At(i))); });
154 printf("}");
155}
156
157} // namespace ck
158#endif
__host__ T exp(T x)
Definition math_v2.hpp:391
__host__ __device__ constexpr T max(T x)
Definition utility/math.hpp:84
Definition statically_indexed_array_multi_index.hpp:121
__host__ __device__ constexpr auto exp(const Tuple< Xs... > &x)
Definition statically_indexed_array_multi_index.hpp:124
__host__ __device__ constexpr auto max(const Tuple< Xs... > &x, const Y &y)
Definition statically_indexed_array_multi_index.hpp:134
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition utility/statically_indexed_array.hpp:45
int32_t index_t
Definition ck.hpp:299
__host__ __device__ constexpr auto operator+=(MultiIndex< NSize > &y, const X &x)
Definition array_multi_index.hpp:34
__host__ __device__ constexpr auto make_statically_indexed_array()
Definition utility/statically_indexed_array.hpp:55
__host__ __device__ constexpr auto unpack(F &&f, X &&x)
Definition functional4.hpp:46
__host__ __device__ constexpr auto operator-=(MultiIndex< NSize > &y, const X &x)
Definition array_multi_index.hpp:42
__host__ __device__ constexpr auto operator-(const MultiIndex< NSize > &a, const T &b)
Definition array_multi_index.hpp:60
__host__ __device__ constexpr auto to_multi_index(const T &x)
Definition array_multi_index.hpp:28
__host__ __device__ constexpr auto operator+(const MultiIndex< NSize > &a, const T &b)
Definition array_multi_index.hpp:50
__host__ __device__ constexpr auto make_zero_multi_index()
Definition array_multi_index.hpp:21
__host__ __device__ constexpr auto operator*(const MultiIndex< NSize > &a, const T &b)
Definition array_multi_index.hpp:70
__host__ __device__ void print_multi_index(const Tuple< Xs... > &x)
Definition statically_indexed_array_multi_index.hpp:147
typename std::enable_if< B, T >::type enable_if_t
Definition enable_if.hpp:27
Array< index_t, N > MultiIndex
Definition array_multi_index.hpp:12
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
Definition utility/tuple.hpp:117
__host__ __device__ constexpr const auto & At(Number< I >) const
Definition utility/tuple.hpp:141
static constexpr value_type value
Definition utility/integral_constant.hpp:13
Definition functional2.hpp:33
typename sequence_gen< NSize, F >::type type
Definition utility/sequence.hpp:295