reference_transpose.hpp Source File

reference_transpose.hpp Source File

Composable Kernel: reference_transpose.hpp Source File
reference_transpose.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8#include <thread>
9
10namespace ck_tile {
11
12template <typename ADataType, typename BDataType>
14{
15 ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
16 ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
17
18 // Ensure the b tensor is sized correctly for N x M
19 if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
20 static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
21 {
22 throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
23 }
24
25 auto f = [&](auto i, auto j) {
26 auto v_a = a(i, j);
28 };
29
30 make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
31}
32
33} // namespace ck_tile
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition tile/host/host_tensor.hpp:329
void reference_transpose_elementwise(const HostTensor< ADataType > &a, HostTensor< BDataType > &b)
Definition reference_transpose.hpp:13
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition pointer.h:1517
const std::vector< std::size_t > & get_lengths() const
Definition tile/host/host_tensor.hpp:198
Definition tile/host/host_tensor.hpp:336
Descriptor mDesc
Definition tile/host/host_tensor.hpp:800