device_batchnorm_backward.hpp Source File

device_batchnorm_backward.hpp Source File

Composable Kernel: device_batchnorm_backward.hpp Source File
device_batchnorm_backward.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"

12namespace ck {
13namespace tensor_operation {
14namespace device {
15
16template <typename XDataType,
17 typename DxDataType,
18 typename DyDataType,
19 typename AccDataType,
20 typename ScaleDataType,
21 typename DscaleDbiasDataType,
22 typename MeanVarDataType,
23 typename DyElementwiseOp,
24 index_t Rank,
25 index_t NumBatchNormReduceDim>
27{
28 static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
29
30 virtual std::unique_ptr<BaseArgument>
31 MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
32 const std::array<index_t, Rank> xStrides,
33 const std::array<index_t, Rank> dyStrides,
34 const std::array<index_t, Rank> dxStrides,
35 const std::array<int, NumBatchNormReduceDim> reduceDims,
36 const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
37 const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
38 const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
39 const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
40 const void* p_x,
41 const void* p_dy,
42 const void* p_scale,
43 const void* p_savedMean,
44 const void* p_savedInvVar,
45 double epsilon,
46 const DyElementwiseOp dy_elementwise_op,
47 void* p_dx,
48 void* p_dscale,
49 void* p_dbias) = 0;
50
51 virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
52};
53
54template <typename XDataType,
55 typename DxDataType,
56 typename DyDataType,
57 typename AccDataType,
58 typename ScaleDataType,
59 typename DscaleDbiasDataType,
60 typename MeanVarDataType,
61 typename DyElementwiseOp,
62 index_t Rank,
63 index_t NumBatchNormReduceDim>
64using DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType,
65 DxDataType,
66 DyDataType,
67 AccDataType,
68 ScaleDataType,
69 DscaleDbiasDataType,
70 MeanVarDataType,
71 DyElementwiseOp,
72 Rank,
73 NumBatchNormReduceDim>>;
74
75} // namespace device
76} // namespace tensor_operation
77} // namespace ck
Definition convolution_backward_data_specialization.hpp:8
std::unique_ptr< DeviceBatchNormBwd< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim > > DeviceBatchNormBwdPtr
Definition device_batchnorm_backward.hpp:64
Definition convolution_backward_data_specialization.hpp:7
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
Definition device_batchnorm_backward.hpp:27
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const void *p_x, const void *p_dy, const void *p_scale, const void *p_savedMean, const void *p_savedInvVar, double epsilon, const DyElementwiseOp dy_elementwise_op, void *p_dx, void *p_dscale, void *p_dbias)=0
static constexpr index_t NumInvariantDim
Definition device_batchnorm_backward.hpp:28