gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp Source File

gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp Source File#

Composable Kernel: gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp Source File
gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
12
13namespace ck {
14
// Kernel entry point: forwards every argument, unchanged and in the same order, to the static
// Run() of the gridwise implementation type passed as the first template parameter.
// NOTE(review): original source line 27 is absent from this listing. According to the symbol
// index at the end of the listing it holds the
// `__global__ void kernel_welford_second_half_batchnorm_forward_final(` declarator;
// restore it before compiling.
15template <typename GridwiseWelfordSecondHalfBatchNormForwardFinal_,
16 typename XDataType,
17 typename YDataType,
18 typename AccDataType,
19 typename ScaleDataType,
20 typename BiasDataType,
21 typename MeanVarDataType,
22 typename YElementwiseOp,
23 typename XYGridDesc_M_K,
24 typename MeanVarCountGridDesc_M_K,
25 typename ScaleBiasGridDesc_M,
26 typename MeanVarGridDesc_M>
28 const XYGridDesc_M_K x_grid_desc_m_k,
29 const XYGridDesc_M_K y_grid_desc_m_k,
30 const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
31 const ScaleBiasGridDesc_M scale_grid_desc_m,
32 const ScaleBiasGridDesc_M bias_grid_desc_m,
33 const MeanVarGridDesc_M mean_var_grid_desc_m,
34 index_t blkgroup_size,
35 index_t num_xy_k_block_tile_iteration,
36 AccDataType epsilon,
37 const MeanVarDataType* const __restrict__ p_in_welford_mean,
38 const MeanVarDataType* const __restrict__ p_in_welford_variance,
39 const int32_t* const __restrict__ p_in_welford_count,
40 const XDataType* const __restrict__ p_x,
41 const ScaleDataType* const __restrict__ p_scale,
42 const BiasDataType* const __restrict__ p_bias,
43 const YElementwiseOp y_elementwise_op,
44 YDataType* const __restrict__ p_y,
45 bool updateMovingAverage,
46 AccDataType averageFactor,
47 MeanVarDataType* const __restrict__ resultRunningMean,
48 MeanVarDataType* const __restrict__ resultRunningVariance,
49 bool saveMeanInvVariance,
50 MeanVarDataType* const __restrict__ resultSaveMean,
51 MeanVarDataType* const __restrict__ resultSaveInvVariance)
52{
// The kernel body is a single delegation; all work happens in the gridwise type's Run().
53 GridwiseWelfordSecondHalfBatchNormForwardFinal_::Run(x_grid_desc_m_k,
54 y_grid_desc_m_k,
55 mean_var_count_grid_desc_m_k,
56 scale_grid_desc_m,
57 bias_grid_desc_m,
58 mean_var_grid_desc_m,
59 blkgroup_size,
60 num_xy_k_block_tile_iteration,
61 epsilon,
62 p_in_welford_mean,
63 p_in_welford_variance,
64 p_in_welford_count,
65 p_x,
66 p_scale,
67 p_bias,
68 y_elementwise_op,
69 p_y,
70 updateMovingAverage,
71 averageFactor,
72 resultRunningMean,
73 resultRunningVariance,
74 saveMeanInvVariance,
75 resultSaveMean,
76 resultSaveInvVariance);
77};
78
// Device-side implementation of the last stage of a multi-block Welford batchnorm forward pass.
// Its static Run() executes four steps in order:
//   1. merge the partial mean / variance / count values read through p_in_welford_* into
//      per-thread results,
//   2. normalize x into y using scale, bias and epsilon,
//   3. optionally blend the results into the running mean / variance,
//   4. optionally store the mean and 1/sqrt(variance + epsilon).
// NOTE(review): this listing omits a number of original source lines: the struct declarator
// just before the opening brace, the type aliases recorded in the symbol index at the end of
// the listing (ThreadClusterLengths_M_K, ThreadwiseWelford, BlockwiseWelford, PassThroughOp,
// ...), the types of the thread-buffer declarations, and several static_for / transfer-type
// openers. The block therefore does not compile as shown.
79template <typename XDataType,
80 typename YDataType,
81 typename AccDataType,
82 typename ScaleDataType,
83 typename BiasDataType,
84 typename MeanVarDataType,
85 typename YElementwiseOp,
86 typename XYGridDesc_M_K,
87 typename MeanVarCountGridDesc_M_K,
88 typename ScaleBiasGridDesc_M,
89 typename MeanVarGridDesc_M,
90 index_t BlockSize,
91 index_t MThreadClusterSize,
92 index_t KThreadClusterSize,
93 index_t MThreadSliceSize,
94 index_t KThreadSliceSize,
95 index_t XSrcYDstVectorDim,
96 index_t XSrcVectorSize,
97 index_t YDstVectorSize,
98 index_t ScaleSrcVectorSize,
99 index_t BiasSrcVectorSize,
100 index_t MeanVarSrcDstVectorSize>
102{
// The thread slice size along the vectorized dimension (0 -> M, 1 -> K) must be a multiple of
// the X load vector size.
103 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
104 (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
105 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
106
// Same divisibility requirement for the Y store vector size.
107 static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
108 (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
109 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
110
// True when the vectorized dimension is dimension 0.
111 static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
112
114
117
120
// NOTE(review): the initializer of thread_cluster_desc and the aliases around it are missing
// from this listing; only fragments remain below.
121 static constexpr auto thread_cluster_desc =
123
128
131
133 BlockSize,
136
138
139 static constexpr auto I0 = Number<0>{};
140 static constexpr auto I1 = Number<1>{};
141
// Block tile extent per dimension: thread cluster size times thread slice size.
142 static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
143 static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
144
// Runs the four steps listed in the header comment for the calling thread.
// The *_grid_desc_* arguments describe the layout of the corresponding global pointers;
// blkgroup_size is the divisor used to split the block id into a group id and a position
// within the group; num_xy_k_block_tile_iteration is the trip count of the normalization loop.
// updateMovingAverage / saveMeanInvVariance gate steps 3 and 4 respectively.
145 __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
146 const XYGridDesc_M_K& y_grid_desc_m_k,
147 const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
148 const ScaleBiasGridDesc_M& scale_grid_desc_m,
149 const ScaleBiasGridDesc_M& bias_grid_desc_m,
150 const MeanVarGridDesc_M& mean_var_grid_desc_m,
151 index_t blkgroup_size,
152 index_t num_xy_k_block_tile_iteration,
153 AccDataType epsilon,
154 const MeanVarDataType* const __restrict__ p_in_welford_mean,
155 const MeanVarDataType* const __restrict__ p_in_welford_variance,
156 const int32_t* const __restrict__ p_in_welford_count,
157 const XDataType* const __restrict__ p_x,
158 const ScaleDataType* const __restrict__ p_scale,
159 const BiasDataType* const __restrict__ p_bias,
160 const YElementwiseOp y_elementwise_op,
161 YDataType* const __restrict__ p_y,
162 bool updateMovingAverage,
163 AccDataType averageFactor,
164 MeanVarDataType* const __restrict__ resultRunningMean,
165 MeanVarDataType* const __restrict__ resultRunningVariance,
166 bool saveMeanInvVariance,
167 MeanVarDataType* const __restrict__ resultSaveMean,
168 MeanVarDataType* const __restrict__ resultSaveInvVariance)
169
170 {
171 using ck::math::sqrt;
172
// NOTE(review): the type of each thread buffer below is missing from this listing (presumably
// StaticBuffer instantiations; confirm against the original file). The declarations of
// scale_thread_buf and bias_thread_buf, which are used later, are missing as well.
174 in_welford_mean_thread_buf;
176 in_welford_var_thread_buf;
178 in_welford_count_thread_buf;
179
181 welford_mean_thread_buf;
183 welford_var_thread_buf;
185 welford_count_thread_buf;
186
188 x_thread_buf;
190 y_thread_buf;
191
194
// blkgroup_id selects a group of blkgroup_size consecutive blocks; block_local_id is this
// block's position inside that group.
195 const index_t thread_local_id = get_thread_local_1d_id();
196 const index_t block_global_id = get_block_1d_id();
197 const index_t blkgroup_id = block_global_id / blkgroup_size;
198 const index_t block_local_id = block_global_id % blkgroup_size;
199
// Convert the flat thread id into its two cluster coordinates (index 0 -> M, index 1 -> K).
200 const auto thread_cluster_idx =
201 thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
202
203 const auto thread_m_cluster_id = thread_cluster_idx[I0];
204 const auto thread_k_cluster_id = thread_cluster_idx[I1];
205
206 using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
207 using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
208 using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
// NOTE(review): the argument lines of the three descriptor initializers are missing here.
209 constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
211 constexpr auto thread_buffer_desc_m =
213 constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
215
// Loader for the partial mean / variance values. Its start coordinate is
// (blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize, thread_k_cluster_id).
216 auto threadwise_mean_var_load_m_k =
217 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
218 AccDataType,
219 MeanVarCountGridDesc_M_K,
220 decltype(thread_buffer_desc_m_1),
221 ThreadBufferLengths_M_1,
223 0,
224 1,
225 1,
226 true>(
227 mean_var_count_grid_desc_m_k,
228 make_multi_index(blkgroup_id * M_BlockTileSize +
229 thread_m_cluster_id * MThreadSliceSize,
230 thread_k_cluster_id * 1));
231
// Loader for the partial counts; same start coordinate as the mean / variance loader.
232 auto threadwise_count_load_m_k =
234 int32_t,
235 MeanVarCountGridDesc_M_K,
236 decltype(thread_buffer_desc_m_1),
237 ThreadBufferLengths_M_1,
239 0,
240 1,
241 1,
242 true>(
243 mean_var_count_grid_desc_m_k,
244 make_multi_index(blkgroup_id * M_BlockTileSize +
245 thread_m_cluster_id * MThreadSliceSize,
246 thread_k_cluster_id * 1));
247
// Wrap the raw partial-result pointers as global-address-space buffers sized by the descriptor.
248 const auto welford_mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
249 p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
250
251 const auto welford_var_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
252 p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
253
254 const auto welford_count_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
255 p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());
256
257 // Step 1: do final welford reduction to get mean and variance
258
// Zero the accumulators before merging. NOTE(review): the opening line of the enclosing loop
// construct is missing from this listing.
260 welford_mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
261 welford_var_thread_buf(I) = type_convert<AccDataType>(0.0f);
262 welford_count_thread_buf(I) = 0;
263 });
264
// Each loop iteration advances the source window by KThreadClusterSize along dimension 1.
265 constexpr auto mean_var_count_thread_copy_step_m_k =
266 make_multi_index(0, KThreadClusterSize);
267
// Keep loading and merging until reducedSize reaches blkgroup_size, stepping by
// KThreadClusterSize per iteration.
268 int32_t reducedSize = 0;
269 while(reducedSize < blkgroup_size)
270 {
271 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
272 welford_mean_global_val_buf,
273 thread_buffer_desc_m_1,
274 make_tuple(I0, I0),
275 in_welford_mean_thread_buf);
276
277 threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
278 welford_var_global_val_buf,
279 thread_buffer_desc_m_1,
280 make_tuple(I0, I0),
281 in_welford_var_thread_buf);
282
283 threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
284 welford_count_global_val_buf,
285 thread_buffer_desc_m_1,
286 make_tuple(I0, I0),
287 in_welford_count_thread_buf);
288
// Fold the freshly loaded partials into the accumulated mean / variance / count.
289 ThreadwiseWelford::Run(in_welford_mean_thread_buf,
290 in_welford_var_thread_buf,
291 in_welford_count_thread_buf,
292 welford_mean_thread_buf,
293 welford_var_thread_buf,
294 welford_count_thread_buf);
295
296 reducedSize += KThreadClusterSize;
297
298 threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
299 mean_var_count_thread_copy_step_m_k);
300 threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
301 mean_var_count_thread_copy_step_m_k);
302 }
303
// NOTE(review): the loop opener, the statement guarded by `if constexpr(I > 0)` and the callee
// that receives the three arguments below are missing from this listing. Presumably the callee
// is BlockwiseWelford::Run, whose (mean, var, count) signature appears in the symbol index;
// confirm against the original file.
305 if constexpr(I > 0)
307
309 welford_mean_thread_buf(I), welford_var_thread_buf(I), welford_count_thread_buf(I));
310 });
311
312 // Step 2: do normalization and output y
313
// Extent along dimension 1 covered by one block across all of its tile iterations; it is used
// below as the per-block starting offset multiplier.
314 const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;
315
316 auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
317 AccDataType,
318 XYGridDesc_M_K,
319 decltype(thread_buffer_desc_m_k),
320 ThreadBufferLengths_M_K,
322 XSrcYDstVectorDim,
323 XSrcVectorSize,
324 1,
325 true>(
326 x_grid_desc_m_k,
327 make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
328 workSizePerBlock * block_local_id +
329 thread_k_cluster_id * KThreadSliceSize));
330
// The y store starts at the same coordinate as the x load and applies y_elementwise_op.
331 auto threadwise_y_store =
333 YDataType,
334 decltype(thread_buffer_desc_m_k),
335 XYGridDesc_M_K,
336 YElementwiseOp,
337 ThreadBufferLengths_M_K,
339 XSrcYDstVectorDim,
340 YDstVectorSize,
342 1,
343 true>(
344 y_grid_desc_m_k,
346 blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
347 workSizePerBlock * block_local_id + thread_k_cluster_id * KThreadSliceSize),
348 y_elementwise_op);
349
350 auto threadwise_scale_load =
352 AccDataType,
353 ScaleBiasGridDesc_M,
354 decltype(thread_buffer_desc_m),
355 ThreadBufferLengths_M,
357 0,
358 ScaleSrcVectorSize,
359 1,
360 true>(
361 scale_grid_desc_m,
362 make_multi_index(blkgroup_id * M_BlockTileSize +
363 thread_m_cluster_id * MThreadSliceSize));
364
365 auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
366 AccDataType,
367 ScaleBiasGridDesc_M,
368 decltype(thread_buffer_desc_m),
369 ThreadBufferLengths_M,
371 0,
372 BiasSrcVectorSize,
373 1,
374 true>(
375 bias_grid_desc_m,
376 make_multi_index(blkgroup_id * M_BlockTileSize +
377 thread_m_cluster_id * MThreadSliceSize));
378
379 const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
380 p_x, x_grid_desc_m_k.GetElementSpaceSize());
381
382 const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
383 p_scale, scale_grid_desc_m.GetElementSpaceSize());
384
385 const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
386 p_bias, bias_grid_desc_m.GetElementSpaceSize());
387
388 auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
389 p_y, y_grid_desc_m_k.GetElementSpaceSize());
390
// Scale and bias are loaded once, before the tile loop, and reused in every iteration.
391 threadwise_scale_load.Run(scale_grid_desc_m,
392 scale_global_val_buf,
393 thread_buffer_desc_m,
394 make_tuple(I0),
395 scale_thread_buf);
396
397 threadwise_bias_load.Run(bias_grid_desc_m,
398 bias_global_val_buf,
399 thread_buffer_desc_m,
400 make_tuple(I0),
401 bias_thread_buf);
402
// Both the x and y windows advance by one block tile along dimension 1 per iteration.
403 constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);
404
405 for(index_t workTiles = 0; workTiles < num_xy_k_block_tile_iteration; ++workTiles)
406 {
407 threadwise_x_load.Run(x_grid_desc_m_k,
408 x_global_val_buf,
409 thread_buffer_desc_m_k,
410 make_tuple(I0, I0),
411 x_thread_buf);
412
// y = x * multiplier + fused_mean_bias, where
//   multiplier      = scale / sqrt(variance + epsilon)
//   fused_mean_bias = bias - mean * multiplier
// NOTE(review): the opening lines of the loops over iM and iK are missing from this listing.
414 AccDataType multiplier =
415 scale_thread_buf[iM] / sqrt(welford_var_thread_buf[iM] + epsilon);
416
417 AccDataType fused_mean_bias =
418 bias_thread_buf[iM] - welford_mean_thread_buf[iM] * multiplier;
419
421 constexpr auto offset =
422 thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
423
424 y_thread_buf(Number<offset>{}) =
425 x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
426 });
427 });
428
429 threadwise_y_store.Run(thread_buffer_desc_m_k,
430 make_tuple(I0, I0),
431 y_thread_buf,
432 y_grid_desc_m_k,
433 y_global_val_buf);
434
435 threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
436 threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_thread_copy_step_m_k);
437 }
438
439 // Step 3: update the moving average of mean and variance (optional)
440
// Only threads with block_local_id == 0 and thread_k_cluster_id == 0 perform the update.
441 if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
442 {
// NOTE(review): the types of the two running-statistics buffers are missing from this listing.
444 running_mean_thread_buf;
446 running_var_thread_buf;
447
448 auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
449 resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
450
451 auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
452 resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
453
454 auto threadwise_mean_var_load_m =
455 ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
456 AccDataType,
457 MeanVarGridDesc_M,
458 decltype(thread_buffer_desc_m),
459 ThreadBufferLengths_M,
461 0,
462 MeanVarSrcDstVectorSize,
463 1,
464 true>(
465 mean_var_grid_desc_m,
466 make_multi_index(blkgroup_id * M_BlockTileSize +
467 thread_m_cluster_id * MThreadSliceSize));
468
// Read the current running values, blend, then write them back to the same buffers.
469 threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
470 running_mean_global_buf,
471 thread_buffer_desc_m,
472 make_tuple(I0),
473 running_mean_thread_buf);
474
475 threadwise_mean_var_load_m.Run(mean_var_grid_desc_m,
476 running_var_global_buf,
477 thread_buffer_desc_m,
478 make_tuple(I0),
479 running_var_thread_buf);
480
481 AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
482
// running = running * (1 - averageFactor) + current * averageFactor
484 running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
485 welford_mean_thread_buf[I] * averageFactor;
486 running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
487 welford_var_thread_buf[I] * averageFactor;
488 });
489
490 auto threadwise_mean_var_store =
492 MeanVarDataType,
493 decltype(thread_buffer_desc_m),
494 MeanVarGridDesc_M,
496 ThreadBufferLengths_M,
498 0,
499 MeanVarSrcDstVectorSize,
501 1,
502 true>(
503 mean_var_grid_desc_m,
504 make_multi_index(blkgroup_id * M_BlockTileSize +
505 thread_m_cluster_id * MThreadSliceSize),
506 PassThroughOp{});
507
508 threadwise_mean_var_store.Run(thread_buffer_desc_m,
509 make_tuple(I0),
510 running_mean_thread_buf,
511 mean_var_grid_desc_m,
512 running_mean_global_buf);
513
514 threadwise_mean_var_store.Run(thread_buffer_desc_m,
515 make_tuple(I0),
516 running_var_thread_buf,
517 mean_var_grid_desc_m,
518 running_var_global_buf);
519 };
520
521 // Step 4: save mean and inv-variance (optional)
522
// Same thread filter as step 3: block_local_id == 0 and thread_k_cluster_id == 0.
523 if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
524 {
525 auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
526 resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
527
528 auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
529 resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
530
531 // calculate inv-variance as 1/sqrt(epsilon+variance)
// welford_var_thread_buf is overwritten in place. This is safe only because step 3, which
// still needs the variance itself, has already run.
533 welford_var_thread_buf(I) =
534 type_convert<AccDataType>(1.0f) / sqrt(epsilon + welford_var_thread_buf[I]);
535 });
536
537 auto threadwise_mean_inv_var_store =
539 MeanVarDataType,
540 decltype(thread_buffer_desc_m),
541 MeanVarGridDesc_M,
543 ThreadBufferLengths_M,
545 0,
546 MeanVarSrcDstVectorSize,
548 1,
549 true>(
550 mean_var_grid_desc_m,
551 make_multi_index(blkgroup_id * M_BlockTileSize +
552 thread_m_cluster_id * MThreadSliceSize),
553 PassThroughOp{});
554
555 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
556 make_tuple(I0),
557 welford_mean_thread_buf,
558 mean_var_grid_desc_m,
559 result_mean_global_buf);
560
561 threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
562 make_tuple(I0),
563 welford_var_thread_buf,
564 mean_var_grid_desc_m,
565 result_inv_var_global_buf);
566 };
567 }
568};
569
570} // namespace ck
Definition ck.hpp:268
__host__ __device__ constexpr auto make_multi_index(Xs &&... xs)
Definition array_multi_index.hpp:15
int32_t index_t
Definition ck.hpp:299
__global__ void kernel_welford_second_half_batchnorm_forward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:27
@ Set
Definition ck.hpp:278
__host__ __device__ constexpr auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition tensor_description/cluster_descriptor.hpp:13
integral_constant< index_t, N > Number
Definition number.hpp:12
__device__ index_t get_block_1d_id()
Definition get_id.hpp:47
__host__ __device__ constexpr Y type_convert(X x)
Definition utility/type_convert.hpp:98
__host__ __device__ constexpr auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition tensor_descriptor_helper.hpp:101
__host__ __device__ constexpr auto make_tuple(Xs &&... xs)
Definition utility/tuple.hpp:211
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
__device__ void block_sync_lds()
Definition synchronization.hpp:16
__host__ __device__ constexpr auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition dynamic_buffer.hpp:472
signed int int32_t
Definition stdint.h:123
static __device__ void Run(T &mean_value, T &var_value, CountDataType &count)
Definition blockwise_welford.hpp:51
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:102
static constexpr index_t K_BlockTileSize
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:143
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:126
static constexpr auto I0
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:139
decltype(make_naive_tensor_descriptor_packed( make_tuple(Number< MThreadSliceSize >{}, Number< 1 >{}))) ThreadReduceSrcDesc_M_1
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:124
tensor_operation::element_wise::PassThrough PassThroughOp
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:137
static constexpr auto thread_cluster_desc
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:121
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:118
static constexpr auto I1
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:140
static constexpr index_t M_BlockTileSize
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:142
static __device__ void Run(const XYGridDesc_M_K &x_grid_desc_m_k, const XYGridDesc_M_K &y_grid_desc_m_k, const MeanVarCountGridDesc_M_K &mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M &scale_grid_desc_m, const ScaleBiasGridDesc_M &bias_grid_desc_m, const MeanVarGridDesc_M &mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:145
BlockwiseWelford< AccDataType, BlockSize, ThreadClusterLengths_M_K, ThreadClusterArrangeOrder > BlockwiseWelford
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:132
static constexpr bool reorder_thread_cluster
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:111
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:113
ThreadwiseWelfordMerge< AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M > ThreadwiseWelford
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:129
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:115
Definition utility/sequence.hpp:43
Definition static_buffer.hpp:16
Definition threadwise_tensor_slice_transfer.hpp:39
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition threadwise_tensor_slice_transfer.hpp:234
Definition threadwise_welford.hpp:83
static __device__ void Run(const SrcMeanBufferType &src_mean_buf, const SrcVarBufferType &src_var_buf, const SrcCountBufferType &src_count_buf, DstMeanBufferType &dst_mean_buf, DstVarBufferType &dst_var_buf, DstCountBufferType &dst_count_buf)
Definition threadwise_welford.hpp:110
Definition utility/functional.hpp:100
Definition functional2.hpp:33
Definition tensor_operation/gpu/element/unary_element_wise_operation.hpp:340