block_dropout.hpp Source File

block_dropout.hpp Source File#

Composable Kernel: block_dropout.hpp Source File
block_dropout.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include "ck_tile/core.hpp"
8
9namespace ck_tile {
10
11// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and
12// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, the generated random
13// numbers are the same; they are also the same for MFMA (on CDNA), WMMA (on RDNA), and the host
14// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp).
15//
16// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of
17// random numbers (ph_subsequence).
18// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and
19// ph_offset).
20// This means that subsequences are non-overlapping, reproducible and independent of mask or window.
21//
22// There are 3 modes (all produce the same results):
23// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates
24// the entire 32x32 tile (64 * 16 = 32 * 32).
25// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4
26// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock >
27// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions
28// are needed for generating a 32x32 tile.
29// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2
30// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp *
31// WG::kM one warp can generate two 16x16 tiles.
32
33namespace detail {
34// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
35constexpr index_t philox_per_tile = 64;
36} // namespace detail
37
39{
40 template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
41 CK_TILE_HOST_DEVICE static constexpr auto
42 MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
43 index_t seqlen_qk_start)
44 {
45 (void)randval_dram_block_window_tmp;
46 (void)seqlen_qk_start;
47
49 }
50};
51
53{
55 index_t i_head,
56 index_t nheads,
57 unsigned long long seed,
58 unsigned long long offset,
59 float rp_undrop_,
60 uint8_t p_undrop_in_uint8_t_,
61 bool is_store_randval_)
63 ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
64 detail::philox_per_tile)),
65 rp_undrop(rp_undrop_),
66 p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
67 is_store_randval(is_store_randval_)
68 {
69 }
70
71 template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
72 CK_TILE_HOST_DEVICE static constexpr auto
73 MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
74 index_t seqlen_qk_start)
75 {
76 constexpr auto config =
77 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
78 using WG = remove_cvref_t<decltype(config.template at<0>())>;
79 constexpr bool IsWG32 = WG::kM == 32;
80 constexpr index_t MWarp = config.template at<1>();
81 constexpr index_t NWarp = config.template at<2>();
83 constexpr index_t kMPerBlock = BlockGemmShape::kM;
84 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
85 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
86 constexpr index_t kNPerStep = NWarp * WG::kN;
87
88 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
89 auto randval_dram_window = [&]() {
90 if constexpr(IsFwd)
91 {
92 return make_tile_window(
93 randval_dram_block_window_tmp.get_bottom_tensor_view(),
95 {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
96 }
97 else
98 {
99 return make_tile_window(
100 randval_dram_block_window_tmp.get_bottom_tensor_view(),
102 {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
103 }
104 }();
105
106 return randval_dram_window;
107 }
108
109 template <typename BlockGemm>
111 {
112 constexpr auto config =
113 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
114 using WG = remove_cvref_t<decltype(config.template at<0>())>;
115 constexpr bool IsWG32 = WG::kM == 32;
116 constexpr index_t MWarp = config.template at<1>();
117 constexpr index_t NWarp = config.template at<2>();
119 constexpr index_t kMPerBlock = BlockGemmShape::kM;
120 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
121 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
122 constexpr index_t kNPerStep = NWarp * WG::kN;
123 constexpr index_t kN1 = 8;
124 constexpr index_t kN0 = kNPerStep / kN1;
125
126 constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
128 ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
129 number<kN1>{},
130 number<1>{});
131
132 constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
133 randval_lds_block_desc_0,
139
140 return randval_lds_block_desc;
141 }
142
143 template <typename BlockGemm>
145 {
146 constexpr auto config =
147 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
148 using WG = remove_cvref_t<decltype(config.template at<0>())>;
149 constexpr bool IsWG32 = WG::kM == 32;
150 constexpr index_t MWarp = config.template at<1>();
151 constexpr index_t NWarp = config.template at<2>();
153 constexpr index_t kMPerBlock = BlockGemmShape::kM;
154 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
155 constexpr index_t NIterPerWarp = 1;
156
157 // The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution,
158 // because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once
159 constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
166
167 // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
168 constexpr auto randval_block_inner_part_dstr_encoding =
169 typename WarpGemmDispatcher<typename WG::ADataType,
170 typename WG::BDataType,
171 typename WG::CDataType,
172 WG::kM,
173 WG::kN,
174 WG::kK,
175 false,
176 IsWG32>::CWarpDstrEncoding{};
177
178 constexpr auto randval_block_part_dstr_encode =
179 detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
180 randval_block_inner_part_dstr_encoding);
181
182 return make_static_tile_distribution(randval_block_part_dstr_encode);
183 }
184
185 template <typename BlockGemm>
187 {
188 constexpr auto config =
189 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
190 using WG = remove_cvref_t<decltype(config.template at<0>())>;
191 constexpr bool IsWG32 = WG::kM == 32;
192 constexpr index_t MWarp = config.template at<1>();
193 constexpr index_t NWarp = config.template at<2>();
195 constexpr index_t kMPerBlock = BlockGemmShape::kM;
196 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
197 constexpr index_t NIterPerWarp = 1;
198
199 constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
206
207 constexpr auto randval_block_part_dstr_encode =
208 detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
209 typename WG::CWarpDstrEncoding{});
210
211 return make_static_tile_distribution(randval_block_part_dstr_encode);
212 }
213
214 template <typename BlockGemm,
215 typename PComputeDataType,
216 typename RandValOutputDataType,
217 typename PComputeWindow,
218 typename RandValDramWindow>
219 CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
220 const index_t start_n0_idx,
221 PComputeWindow& p_compute,
222 RandValDramWindow& randval_dram_window) const
223 {
224 constexpr auto config =
225 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
226 using WG = remove_cvref_t<decltype(config.template at<0>())>;
227 constexpr bool IsWG32 = WG::kM == 32;
228 constexpr index_t MWarp = config.template at<1>();
229 constexpr index_t NWarp = config.template at<2>();
231 constexpr index_t kMPerBlock = BlockGemmShape::kM;
232 constexpr index_t kNPerBlock = BlockGemmShape::kN;
233 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
234 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
235 constexpr index_t kNPerStep = NWarp * WG::kN;
236
237 // randval tile in LDS
239 reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
240
241 auto randval_lds_window = make_tile_window(
242 randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
243
244 // register distribute
245 auto randval_dist_generated =
247
248 const auto randval_lds_read_window =
249 make_tile_window(randval_lds_window.get_bottom_tensor_view(),
250 randval_lds_window.get_window_lengths(),
251 randval_lds_window.get_window_origin(),
253
254 const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
255 const index_t iMWarp = get_warp_id() / NWarp;
256 const index_t iNWarp = get_warp_id() % NWarp;
257
258 auto generate_randval = [&](auto i_m0, auto i_n0) {
259 // Generate random numbers
260 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
261 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
262 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
263 if constexpr(IsWG32)
264 {
265 // Generate the whole 32x32 tile at once (each tile consists of random numbers taken
266 // from a separate subsequence of Philox)
267 const unsigned long long ph_subsequence =
268 bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
269 const index_t ph_offset = get_lane_id();
270 const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
271 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
272 ph.get_random_16x8(random_uint8_t, ph_subsequence);
273 }
274 else
275 {
276 // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
277 // MIterPerWarp is equal to 1 or 2)
278 const unsigned long long ph_subsequence =
279 bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
280 const index_t subtile_m0 = wg_m0 % 2;
281 if constexpr(get_warp_size() == 32)
282 {
283 const index_t ph_offset = (get_lane_id() & 15) +
284 (((get_lane_id() >> 4) & 1) << 5) +
285 ((wg_n0 % 2) << 4);
286 const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
287 if constexpr(MIterPerWarp == 1)
288 {
289 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
291 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
292 }
293 else
294 {
295 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
296 ph.get_random_16x8(random_uint8_t, ph_subsequence);
297 }
298 }
299 else
300 {
301 const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
302 const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
303 const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
304 if constexpr(MIterPerWarp == 1)
305 {
306 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
308 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
309 }
310 else
311 {
312 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
314 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
315 }
316 }
317 }
318
319 constexpr auto randval_dist_generated_spans =
320 decltype(randval_dist_generated)::get_distributed_spans();
321 int i_random_idx = 0;
322 sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
323 sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
324 constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
325 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
326 });
327 });
328 // Transpose randval using LDS
329 store_tile(randval_lds_window, randval_dist_generated);
331 const auto randval = load_tile(randval_lds_read_window);
333 return randval;
334 };
335
337 {
338 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
339 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
340 const auto randval = generate_randval(i_m0, i_n0);
341 // save to Global
342 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
343 store_tile(randval_dram_window, randval_store);
344 move_tile_window(randval_dram_window, {0, kNPerStep});
345 });
346 move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
347 });
348 move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
349 }
350 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
351 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
352 const auto randval = generate_randval(i_m0, i_n0);
353 // Drop values of P based on the generated probabilities
354 constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
355 sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
356 sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
357 constexpr auto p_idx0 =
358 tile_distributed_index<i_m0 * MIterPerWarp +
359 idx0.impl_.template at<0>()>{};
360 constexpr auto p_idx1 =
362 idx1.impl_.template at<1>(),
363 idx1.impl_.template at<2>()>{};
364 constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
365 constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
366 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
367 ? p_compute[p_idx] * rp_undrop
368 : PComputeDataType(0);
369 });
370 });
371 });
372 });
373 }
374
375 const unsigned long long ph_seed;
376 const unsigned long long ph_head_offset;
377 const float rp_undrop;
380};
381
382// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be
383// replaced with NullBlockDropout. This requires changes in xformers and other libs.
384template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
386
387template <bool IsWG32_, bool IsStoreRandval_>
388struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
389{
390 static constexpr bool IsDropout = false;
391 static constexpr bool IsStoreRandval = IsStoreRandval_;
392
393 template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
394 CK_TILE_HOST_DEVICE static constexpr auto
395 MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
396 index_t seqlen_qk_start)
397 {
398 (void)randval_dram_block_window_tmp;
399 (void)seqlen_qk_start;
400
402 }
403};
404
405template <bool IsWG32_, bool IsStoreRandval_>
406struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
407{
408 static constexpr bool IsDropout = true;
409 static constexpr bool IsStoreRandval = IsStoreRandval_;
410
412 index_t i_head,
413 index_t nheads,
414 unsigned long long seed,
415 unsigned long long offset,
416 float rp_undrop_,
417 uint8_t p_undrop_in_uint8_t_)
419 ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
420 detail::philox_per_tile)),
421 rp_undrop(rp_undrop_),
422 p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
423 {
424 }
425
426 template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
427 CK_TILE_HOST_DEVICE static constexpr auto
428 MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
429 index_t seqlen_qk_start)
430 {
431 constexpr auto config =
432 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
433 using WG = remove_cvref_t<decltype(config.template at<0>())>;
434 constexpr bool IsWG32 = WG::kM == 32;
435 constexpr index_t MWarp = config.template at<1>();
436 constexpr index_t NWarp = config.template at<2>();
438 constexpr index_t kMPerBlock = BlockGemmShape::kM;
439 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
440 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
441 constexpr index_t kNPerStep = NWarp * WG::kN;
442
443 const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
444 auto randval_dram_window = [&]() {
445 if constexpr(IsFwd)
446 {
447 return make_tile_window(
448 randval_dram_block_window_tmp.get_bottom_tensor_view(),
450 {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
451 }
452 else
453 {
454 return make_tile_window(
455 randval_dram_block_window_tmp.get_bottom_tensor_view(),
457 {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
458 }
459 }();
460
461 return randval_dram_window;
462 }
463
464 template <typename BlockGemm>
466 {
467 constexpr auto config =
468 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
469 using WG = remove_cvref_t<decltype(config.template at<0>())>;
470 constexpr bool IsWG32 = WG::kM == 32;
471 constexpr index_t MWarp = config.template at<1>();
472 constexpr index_t NWarp = config.template at<2>();
474 constexpr index_t kMPerBlock = BlockGemmShape::kM;
475 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
476 constexpr index_t NIterPerWarp = 1;
477
478 constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
485
486 constexpr auto randval_block_inner_part_dstr_encoding =
487 typename WarpGemmDispatcher<typename WG::ADataType,
488 typename WG::BDataType,
489 typename WG::CDataType,
490 WG::kM,
491 WG::kN,
492 WG::kK,
493 false,
494 IsWG32>::CWarpDstrEncoding{};
495 static_assert(
496 std::is_same_v<remove_cvref_t<decltype(randval_block_inner_part_dstr_encoding)>,
497 typename WG::CWarpDstrEncoding>);
498
499 constexpr auto randval_block_part_dstr_encode =
500 detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
501 randval_block_inner_part_dstr_encoding);
502
503 return make_static_tile_distribution(randval_block_part_dstr_encode);
504 }
505
506 template <typename BlockGemm,
507 typename RandValOutputDataType,
508 typename PComputeWindow,
509 typename RandValDramWindow>
510 CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
511 const index_t start_n0_idx,
512 PComputeWindow& p_compute,
513 RandValDramWindow& randval_dram_window) const
514 {
515 constexpr auto config =
516 BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
517 using WG = remove_cvref_t<decltype(config.template at<0>())>;
518 constexpr bool IsWG32 = WG::kM == 32;
519 constexpr index_t MWarp = config.template at<1>();
520 constexpr index_t NWarp = config.template at<2>();
522 constexpr index_t kMPerBlock = BlockGemmShape::kM;
523 constexpr index_t kNPerBlock = BlockGemmShape::kN;
524 constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
525 constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
526 constexpr index_t kNPerStep = NWarp * WG::kN;
527
528 // register distribute
529 auto randval_dist_generated =
531
532 const index_t iMWarp = get_warp_id() / NWarp;
533 const index_t iNWarp = get_warp_id() % NWarp;
534
535 auto generate_randval = [&](auto i_m0, auto i_n0) {
536 // Generate random numbers
537 uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
538 const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
539 const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
540 if constexpr(IsWG32)
541 {
542 // Generate the whole 32x32 tile at once (each tile consists of random numbers
543 // taken from a separate subsequence of Philox)
544 const unsigned long long ph_subsequence =
545 bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
546 const index_t ph_offset = get_lane_id();
547 const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
548 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
549 ph.get_random_16x8(random_uint8_t, ph_subsequence);
550 }
551 else
552 {
553 // Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
554 // MIterPerWarp is equal to 1 or 2)
555 const unsigned long long ph_subsequence =
556 bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
557 const index_t subtile_m0 = wg_m0 % 2;
558 if constexpr(get_warp_size() == 32)
559 {
560 const index_t ph_offset = (get_lane_id() & 15) +
561 (((get_lane_id() >> 4) & 1) << 5) +
562 ((wg_n0 % 2) << 4);
563 const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
564 if constexpr(MIterPerWarp == 1)
565 {
566 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
568 random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
569 }
570 else
571 {
572 static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
573 ph.get_random_16x8(random_uint8_t, ph_subsequence);
574 }
575 }
576 else
577 {
578 const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
579 const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
580 const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
581 if constexpr(MIterPerWarp == 1)
582 {
583 static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
585 random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
586 }
587 else
588 {
589 static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
591 random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
592 }
593 }
594 }
595
596 constexpr auto randval_dist_generated_spans =
597 decltype(randval_dist_generated)::get_distributed_spans();
598 int i_random_idx = 0;
599 sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
600 sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
601 constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
602 randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
603 });
604 });
605 return randval_dist_generated;
606 };
607
608 static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
609 static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
610 const auto randval = generate_randval(i_m0, i_n0);
611 // Drop values of P based on the generated probabilities; a negative sign is used to
612 // distinguish such values later in the bwd pipeline.
613 constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
614 sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
615 sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
616 constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
617 constexpr auto p_idx0 =
618 tile_distributed_index<i_m0 * MIterPerWarp +
619 idx0.impl_.template at<0>(),
620 idx0.impl_.template at<1>(),
621 idx0.impl_.template at<2>()>{};
622 constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
623 constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
624 p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
625 ? p_compute[p_idx]
626 : -p_compute[p_idx];
627 });
628 });
629 // save to Global
630 if constexpr(IsStoreRandval)
631 {
632 const auto randval_store = cast_tile<RandValOutputDataType>(randval);
633 store_tile(randval_dram_window, randval_store);
634 move_tile_window(randval_dram_window, {kMPerStep, 0});
635 }
636 });
637 if constexpr(IsStoreRandval)
638 {
639 move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
640 }
641 });
642 if constexpr(IsStoreRandval)
643 {
644 move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
645 }
646 }
647
648 const unsigned long long ph_seed;
649 const unsigned long long ph_head_offset;
650 const float rp_undrop;
652};
653
654} // namespace ck_tile
Definition philox_rand.hpp:12
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t idx) const
Definition philox_rand.hpp:75
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t idx0, const index_t idx1) const
Definition philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition philox_rand.hpp:42
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition arch.hpp:385
constexpr index_t philox_per_tile
Definition block_dropout.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition tile_distribution_encoding.hpp:457
Definition tile/core/algorithm/cluster_descriptor.hpp:13
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition warp_gemm_dispatcher.hpp:182
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE index_t get_lane_id()
Definition arch.hpp:101
CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength &low_length)
Definition coordinate_transform.hpp:1558
CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
Definition arch.hpp:63
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition tile/core/arch/amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_HOST_DEVICE constexpr auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition tile/core/tensor/tensor_descriptor.hpp:274
CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths &low_lengths)
Definition coordinate_transform.hpp:1615
CK_TILE_DEVICE index_t get_warp_id(bool_constant< ReturnSgpr >={})
Definition arch.hpp:104
CK_TILE_HOST_DEVICE constexpr Y bit_cast(const X &x)
Definition bit_cast.hpp:11
CK_TILE_HOST_DEVICE constexpr auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition tile/core/tensor/tensor_descriptor.hpp:203
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE constexpr auto make_null_tile_window(const WindowLengths &window_lengths)
Definition null_tile_window.hpp:66
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition tile_distribution.hpp:480
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
unsigned char uint8_t
Definition stdint.h:124
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:395
static constexpr bool IsStoreRandval
Definition block_dropout.hpp:391
static constexpr bool IsDropout
Definition block_dropout.hpp:390
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_)
Definition block_dropout.hpp:411
const unsigned long long ph_seed
Definition block_dropout.hpp:648
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValTileDistribution()
Definition block_dropout.hpp:465
static constexpr bool IsStoreRandval
Definition block_dropout.hpp:409
const uint8_t p_undrop_in_uint8_t
Definition block_dropout.hpp:651
const unsigned long long ph_head_offset
Definition block_dropout.hpp:649
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:428
static constexpr bool IsDropout
Definition block_dropout.hpp:408
CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition block_dropout.hpp:510
const float rp_undrop
Definition block_dropout.hpp:650
Definition block_dropout.hpp:385
const uint8_t p_undrop_in_uint8_t
Definition block_dropout.hpp:378
CK_TILE_HOST_DEVICE BlockDropout(index_t i_batch, index_t i_head, index_t nheads, unsigned long long seed, unsigned long long offset, float rp_undrop_, uint8_t p_undrop_in_uint8_t_, bool is_store_randval_)
Definition block_dropout.hpp:54
const float rp_undrop
Definition block_dropout.hpp:377
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValTileDistribution()
Definition block_dropout.hpp:144
const unsigned long long ph_head_offset
Definition block_dropout.hpp:376
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValLdsBlockDescriptor()
Definition block_dropout.hpp:110
const bool is_store_randval
Definition block_dropout.hpp:379
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:73
CK_TILE_HOST_DEVICE void Run(void *randval_ptr, const index_t start_n0_idx, PComputeWindow &p_compute, RandValDramWindow &randval_dram_window) const
Definition block_dropout.hpp:219
const unsigned long long ph_seed
Definition block_dropout.hpp:375
static CK_TILE_HOST_DEVICE constexpr auto MakeRandValLdsShuffleTileDistribution()
Definition block_dropout.hpp:186
Definition block_dropout.hpp:39
static CK_TILE_HOST_DEVICE constexpr auto MakeRandvalDramWindow(RandValDramBlockWindowTmp &randval_dram_block_window_tmp, index_t seqlen_qk_start)
Definition block_dropout.hpp:42
Definition coordinate_transform.hpp:1392
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
Definition tile_distribution.hpp:42
static constexpr auto impl_
Definition tile_distribution.hpp:45
Definition tile_distribution_encoding.hpp:26
Definition tile/core/container/tuple.hpp:192