device_gemm_splitk.hpp Source File

device_gemm_splitk.hpp Source File

Composable Kernel: device_gemm_splitk.hpp Source File
device_gemm_splitk.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"
10
11namespace ck {
12namespace tensor_operation {
13namespace device {
14
15template <typename ALayout,
16 typename BLayout,
17 typename CLayout,
18 typename ADataType,
19 typename BDataType,
20 typename CDataType,
21 typename AElementwiseOperation,
22 typename BElementwiseOperation,
23 typename CElementwiseOperation,
24 typename ComputeType = CDataType>
26{
27 virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
28 const void* p_b,
29 void* p_c,
33 ck::index_t StrideA,
34 ck::index_t StrideB,
35 ck::index_t StrideC,
36 AElementwiseOperation a_element_op,
37 BElementwiseOperation b_element_op,
38 CElementwiseOperation c_element_op,
39 ck::index_t KBatch) = 0;
40
41 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
42};
43
44template <typename ALayout,
45 typename BLayout,
46 typename CLayout,
47 typename ADataType,
48 typename BDataType,
49 typename CDataType,
50 typename AElementwiseOperation,
51 typename BElementwiseOperation,
52 typename CElementwiseOperation,
53 typename ComputeType = CDataType>
54using DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout,
55 BLayout,
56 CLayout,
57 ADataType,
58 BDataType,
59 CDataType,
60 AElementwiseOperation,
61 BElementwiseOperation,
62 CElementwiseOperation,
63 ComputeType>>;
64
65} // namespace device
66} // namespace tensor_operation
67} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceGemmSplitK< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ComputeType > > DeviceGemmSplitKPtr
Definition device_gemm_splitk.hpp:54
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_gemm_splitk.hpp:26
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ck::index_t KBatch)=0