reference_permute.hpp Source File

reference_permute.hpp Source File

Composable Kernel: reference_permute.hpp Source File
reference_permute.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9#include <numeric>
10#include <functional>
11
12namespace ck_tile {
13
14/*
15 this will do permute + contiguous like functionality in pytorch
16*/
17template <typename DataType>
18CK_TILE_HOST void
19reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
20{
21 const auto x_len = x.mDesc.get_lengths();
22 const auto y_len = y.mDesc.get_lengths();
23 assert(x_len.size() == y_len.size());
24 index_t rank = x_len.size();
25 const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
26 const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
27 assert(x_elm == y_elm);
28 (void)y_elm;
29
30 auto f = [&](auto i_element) {
31 std::vector<size_t> y_coord = [&]() {
32 std::vector<size_t> tmp(rank, 0);
33 size_t r = i_element;
34 for(index_t i = rank - 1; i >= 0; i--)
35 {
36 tmp[i] = r % y_len[i];
37 r = r / y_len[i];
38 }
39 return tmp;
40 }();
41
42 std::vector<size_t> x_coord = [&]() {
43 std::vector<size_t> tmp(rank, 0);
44 for(index_t i = 0; i < rank; i++)
45 {
46 tmp[perm[i]] = y_coord[i];
47 }
48 return tmp;
49 }();
50
51 // do permute
52 y(y_coord) = x(x_coord);
53 };
54
55 make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
56}
57
58template <typename DataType>
59CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
60{
61 auto x_shape = x.get_lengths();
62 ck_tile::index_t rank = perm.size();
63 std::vector<ck_tile::index_t> y_shape = [&]() {
64 std::vector<ck_tile::index_t> tmp(rank, 0);
65 for(int i = 0; i < static_cast<int>(rank); i++)
66 {
67 tmp[i] = x_shape[perm[i]];
68 }
69 return tmp;
70 }();
71
72 HostTensor<DataType> y(y_shape);
73 reference_permute(x, y, perm);
74 return y;
75}
76} // namespace ck_tile
#define CK_TILE_HOST
Definition config.hpp:40
__host__ __device__ constexpr auto rank(const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition layout_utils.hpp:310
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST void reference_permute(const HostTensor< DataType > &x, HostTensor< DataType > &y, std::vector< index_t > perm)
Definition reference_permute.hpp:19
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition tile/host/host_tensor.hpp:390
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800