device_multiple_reduce.hpp Source File

device_multiple_reduce.hpp Source File

Composable Kernel: device_multiple_reduce.hpp Source File
device_multiple_reduce.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <vector>
7#include <memory>
8#include <array>
9#include <iostream>
10
11#include "ck/ck.hpp"
14
15namespace ck {
16namespace tensor_operation {
17namespace device {
18
19template <index_t Rank,
20 index_t NumReduceDim,
21 index_t NumReduction,
22 typename InElementwiseOperationTuple,
23 typename AccElementwiseOperationTuple>
25{
26 static constexpr index_t NumInputDim = Rank;
27 static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;
28
29 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
30 const std::array<index_t, NumInputDim> inLengths,
31 const std::array<index_t, NumInputDim> inStrides,
32 const std::array<index_t, NumOutputDim> outLengths,
33 const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
34 const std::array<int, NumReduceDim> reduceDims,
35 const std::array<double, NumReduction> alphas,
36 const std::array<double, NumReduction> betas,
37 const void* in_dev,
38 const std::array<void*, NumReduction> out_dev_buffers,
39 const InElementwiseOperationTuple in_elementwise_op_tuple,
40 const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0;
41
42 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
43};
44
45template <index_t Rank,
46 index_t NumReduceDim,
47 index_t NumReduction,
48 typename InElementwiseOperationTuple,
49 typename AccElementwiseOperationTuple>
50using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank,
51 NumReduceDim,
52 NumReduction,
53 InElementwiseOperationTuple,
54 AccElementwiseOperationTuple>>;
55
56} // namespace device
57} // namespace tensor_operation
58} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceMultipleReduce< Rank, NumReduceDim, NumReduction, InElementwiseOperationTuple, AccElementwiseOperationTuple > > DeviceMultipleReducePtr
Definition device_multiple_reduce.hpp:50
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_multiple_reduce.hpp:25
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumInputDim > inLengths, const std::array< index_t, NumInputDim > inStrides, const std::array< index_t, NumOutputDim > outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > outStrides, const std::array< int, NumReduceDim > reduceDims, const std::array< double, NumReduction > alphas, const std::array< double, NumReduction > betas, const void *in_dev, const std::array< void *, NumReduction > out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple)=0
static constexpr index_t NumInputDim
Definition device_multiple_reduce.hpp:26
static constexpr index_t NumOutputDim
Definition device_multiple_reduce.hpp:27
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0