gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp File Reference

gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp File Reference

Composable Kernel: gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp File Reference
gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp File Reference

Go to the source code of this file.

Classes

struct  ck::GridwiseWelfordSecondHalfReduceFirstHalf< XDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, XYGridDesc_M_K, MeanVarGridDesc_M, MeanVarCountGridDesc_M_K, DscaleDbiasGridDesc_M_G, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyVectorDim, XSrcVectorSize, DySrcVectorSize, MeanVarSrcVectorSize >

Namespaces

namespace  ck

Functions

template<typename GridwiseWelfordSecondHalfReduceFirstHalf_, typename XDataType, typename DyDataType, typename AccDataType, typename ScaleDataType, typename DscaleDbiasDataType, typename MeanVarDataType, typename DyElementwiseOp, typename XYGridDesc_M_K, typename MeanVarGridDesc_M, typename MeanVarCountGridDesc_M_K, typename DscaleDbiasGridDesc_M_G>
__global__ void ck::kernel_welford_second_half_reduce_first_half (const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const DyElementwiseOp dy_elementwise_op, MeanVarDataType *const __restrict__ p_out_welford_mean, MeanVarDataType *const __restrict__ p_out_welford_inv_variance, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, DscaleDbiasDataType *const __restrict__ p_reduce_dscale, DscaleDbiasDataType *const __restrict__ p_reduce_dbias)