reference_batched_rotary_position_embedding.hpp Source File

reference_batched_rotary_position_embedding.hpp Source File

Composable Kernel: reference_batched_rotary_position_embedding.hpp Source File
reference_batched_rotary_position_embedding.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9#include <cassert>
10#include <thread>
11
12namespace ck_tile {
13
14template <typename DataType, typename ComputeDataType = float>
16 const HostTensor<DataType>& cos_sd,
17 const HostTensor<DataType>& sin_sd,
18 bool interleaved,
19 HostTensor<DataType>& output_bsd,
20 bool use_1_row_sin_cos = false)
21{
22 assert(cos_sd.get_num_of_dimension() == 2 && sin_sd.get_num_of_dimension() == 2);
23 assert(cos_sd.get_length(0) == sin_sd.get_length(0) &&
24 cos_sd.get_length(1) == sin_sd.get_length(1));
25
26 const index_t rotary_dim = cos_sd.get_length(1) * 2;
27 assert(static_cast<std::size_t>(rotary_dim) <= input_bsd.get_length(2));
28
29 output_bsd.ForEach([&](auto& self, auto i) {
30 const index_t i_d = i[2];
31 if(rotary_dim <= i_d)
32 {
33 self(i) = input_bsd(i);
34 return;
35 }
36 assert(i_d < rotary_dim);
37
38 const index_t i_s = i[1];
39 const index_t i_s_cos_sin = (use_1_row_sin_cos ? 0 : i_s);
40
41 const ComputeDataType cos = type_convert<ComputeDataType>(
42 interleaved ? cos_sd(i_s_cos_sin, i_d / 2)
43 : cos_sd(i_s_cos_sin, i_d % cos_sd.get_length(1)));
44 const ComputeDataType sin = type_convert<ComputeDataType>(
45 interleaved ? sin_sd(i_s_cos_sin, i_d / 2)
46 : sin_sd(i_s_cos_sin, i_d % sin_sd.get_length(1)));
47
48 const ComputeDataType half_rotated_input = [&] {
49 const index_t i_b = i[0];
50
51 if(interleaved)
52 {
53 const bool is_even = (i_d % 2 == 0);
54 const index_t pos = i_d + (is_even ? 1 : -1);
55 const ComputeDataType sign = (is_even ? -1 : 1);
56 return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
57 }
58 else
59 {
60 const index_t half_rdim = (rotary_dim / 2);
61 const index_t pos = (i_d + half_rdim) % rotary_dim;
62 const ComputeDataType sign = (pos < half_rdim ? 1 : -1);
63 return sign * type_convert<ComputeDataType>(input_bsd(i_b, i_s, pos));
64 }
65 }();
66 ComputeDataType result =
67 type_convert<ComputeDataType>(input_bsd(i)) * cos + half_rotated_input * sin;
68
69 self(i) = type_convert<DataType>(result);
70 });
71}
72
73} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST T cos(T x)
Definition tile/core/numeric/math.hpp:752
CK_TILE_HOST T sin(T x)
Definition tile/core/numeric/math.hpp:698
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_HOST void reference_batched_rotary_position_embedding(const HostTensor< DataType > &input_bsd, const HostTensor< DataType > &cos_sd, const HostTensor< DataType > &sin_sd, bool interleaved, HostTensor< DataType > &output_bsd, bool use_1_row_sin_cos=false)
Definition reference_batched_rotary_position_embedding.hpp:15
Definition tile/host/host_tensor.hpp:336
void ForEach(F &&f)
Definition tile/host/host_tensor.hpp:437
std::size_t get_num_of_dimension() const
Definition tile/host/host_tensor.hpp:396
std::size_t get_length(std::size_t dim) const
Definition tile/host/host_tensor.hpp:388