14template <
typename Problem_,
typename Policy_ = BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy>
39 static_assert(
kQLoadOnce == Policy::QLoadOnce);
40 static constexpr bool kKLoadOnce = BlockFmhaShape::kM0 >= 64;
54 static_assert(
kSubQKHeaddim <= 256,
"hdim bigger than 256 is not suitable for this pipeline!");
63 Problem::kPadHeadDimQ;
65 Problem::kPadHeadDimV;
69 static constexpr auto BiasEnum = Problem::BiasEnum;
70 static constexpr bool kStoreLSE = Problem::kStoreLSE;
83 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
84 return Policy::template GetAlignmentV<Problem>();
86 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
92 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
95 if constexpr(Problem::kBlockPerCu != -1)
96 return Problem::kBlockPerCu;
125 static constexpr const char*
name =
"qr_async_trload";
133 template <
typename QDramBlockWindowTmp,
134 typename KDramBlockWindowTmp,
135 typename VDramBlockWindowTmp,
136 typename BiasDramBlockWindowTmp,
137 typename LSEaccDramBlockWindowTmp,
138 typename PositionEncoding>
140 operator()(
const QDramBlockWindowTmp& q_dram_block_window_tmp,
141 const KDramBlockWindowTmp& k_dram_block_window_tmp,
142 const VDramBlockWindowTmp& v_dram_block_window_tmp,
143 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
144 LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp,
146 PositionEncoding position_encoding,
148 void* smem_ptr)
const
151 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
152 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
153 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
156 static_assert(
kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
158 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
159 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
I1] &&
160 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
161 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
I1] &&
162 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
163 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
I1],
165 ignore = bias_dram_block_window_tmp;
166 ignore = position_encoding;
168 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
169 constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
171 using SaccBlockTileType =
decltype(gemm_0.MakeCBlockTile());
172 auto s_acc = SaccBlockTileType{};
175 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
176 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
178 using OaccBlockTileType =
decltype(gemm_1.MakeCBlockTile());
180 auto o_acc = OaccBlockTileType{};
189 auto m = MLBlockTileType{};
190 auto l = MLBlockTileType{};
196 const auto q_origin = q_dram_block_window_tmp.get_window_origin();
197 const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
203 const index_t logical_num_total_loop =
205 if(logical_num_total_loop <= 0)
225 q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
228 static_cast<QDataType*
>(smem_ptr), Policy::template MakeQLdsBlockDescriptor<Problem>());
232 Policy::template MakeQLdsBlockDescriptor<Problem, true>());
234 auto q_lds_store_window =
236 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
239 auto q_lds_read_window =
241 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
243 Policy::template MakeQRegTileDistribution<Problem>());
248 const index_t physical_seqlen_k_start = logical_seqlen_k_start;
249 const index_t physical_seqlen_k_end = logical_seqlen_k_end;
254 const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
258 {physical_seqlen_k_start, 0},
259 Policy::template MakeKDramTileDistribution<Problem>());
262 static_cast<KDataType*
>(smem_ptr), Policy::template MakeKLdsBlockDescriptor<Problem>());
265 Policy::template MakeKLdsBlockDescriptor<Problem, false, true>());
267 auto k_lds_write_window =
269 Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
271 auto k_lds_read_window =
275 Policy::template MakeKRegTileDistribution<Problem>());
279 reinterpret_cast<SaccDataType*
>(
reinterpret_cast<char*
>(smem_ptr) +
280 Policy::template GetSmemSizeK<Problem>()),
281 Policy::template MakeSLdsBlockDescriptor<Problem>());
283 s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
284 auto s_read_lds_window =
286 Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
288 Policy::template MakeSRegTileDistribution<Problem>());
293 {physical_seqlen_k_start, 0},
294 Policy::template MakeVDramTileDistribution<Problem>());
297 reinterpret_cast<VDataType*
>(
static_cast<char*
>(smem_ptr) +
298 Policy::template GetSmemSizeK<Problem>() +
299 Policy::template GetSmemSizeS<Problem>()),
300 Policy::template MakeVLdsBlockDescriptor<Problem>());
302 reinterpret_cast<VDataType*
>(
static_cast<char*
>(smem_ptr) +
303 Policy::template GetSmemSizeK<Problem>() +
304 Policy::template GetSmemSizeS<Problem>()),
305 Policy::template MakeVLdsBlockDescriptor<Problem, true>());
306 auto v_lds_write_window =
308 Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
311 auto v_lds_read_window =
315 Policy::template MakeVRegTileDistribution<Problem>());
318 auto q_tile =
load_tile(q_lds_read_window);
327 static_assert(1 <= k0_loops);
328 static_assert(1 <= k1_loops);
333 constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
334 constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
347 if constexpr(1 < k0_loops)
349 static_for<0, k0_loops - 1, 1>{}([&](
auto i_k0) {
350 if constexpr(i_k0 == 0)
359 auto k_tile =
load_tile(k_lds_read_window);
376 if constexpr(k0_loops == 1)
385 auto k_tile =
load_tile(k_lds_read_window);
395 if(i_total_loops == (num_total_loop - 1))
397 const auto k_origin =
398 make_tuple(
kN0 * i_total_loops + physical_seqlen_k_start, 0);
402 physical_seqlen_k_start_ = physical_seqlen_k_start,
403 physical_seqlen_k_end_ = physical_seqlen_k_end](
auto tile_idx) {
404 const auto col = k_origin.at(
I0) + tile_idx.at(
I1);
407 return physical_seqlen_k_end_ <= col;
415 const auto k_origin =
make_tuple(
kN0 * i_total_loops + physical_seqlen_k_start, 0);
417 bool need_perpixel_check =
419 if(need_perpixel_check)
423 const auto row = q_origin.at(
I0) + tile_idx.at(
I0);
424 const auto col = k_origin.at(
I0) + tile_idx.at(
I1);
425 return mask.IsOutOfBound(row, col);
463 const auto m_old = m;
465 [](
auto& e0,
auto e1,
auto e2) { e0 =
max(e1, e2); }, m, m_old, m_local);
468 s_new.get_tile_distribution());
486 constexpr auto p_spans =
decltype(p_compute)::get_distributed_spans();
489 auto row_max = scale_s * get_validated_m(m[i_idx]);
491 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
495 p_compute(i_j_idx) =
exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
501 p_compute(i_j_idx) =
exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
505 p_compute(i_j_idx) =
exp2(scale_s * s_new[i_j_idx] - row_max);
518 Policy::template MakePRegTileDistribution<Problem>());
522 constexpr auto o_spans =
decltype(o_acc)::get_distributed_spans();
525 const auto tmp = [&]() {
529 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
535 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
539 auto row_max = scale_s * get_validated_m(m[i_idx]);
540 return exp2(scale_s * m_old[i_idx] - row_max);
544 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
546 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
548 o_acc(i_j_idx) *= tmp;
556 if constexpr(1 < k1_loops)
558 static_for<0, k1_loops - 1, 1>{}([&](
auto i_k1) {
579 }
while(++i_total_loops < num_total_loop);
586 constexpr auto lse_acc_spans =
decltype(lse_acc)::get_distributed_spans();
592 lse_acc(i_idx) = m_[i_idx] /
C_LOG2E +
log(l_[i_idx]);
598 lse_acc(i_idx) = m_[i_idx] /
C_LOG2E +
log(l_[i_idx]);
602 lse_acc(i_idx) = m_[i_idx] * scale_s /
C_LOG2E +
log(l_[i_idx]);
611 constexpr auto o_spans =
decltype(o_acc)::get_distributed_spans();
615 const auto tmp = [&]() {
619 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
625 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
626 o_acc(i_j_idx) *= tmp;
634 template <
typename QDramBlockWindowTmp,
635 typename KDramBlockWindowTmp,
636 typename VDramBlockWindowTmp,
637 typename BiasDramBlockWindowTmp,
638 typename LSEaccDramBlockWindowTmp,
639 typename PositionEncoding>
641 operator()(
const QDramBlockWindowTmp& __restrict__ q_dram_block_window_tmp,
642 const KDramBlockWindowTmp& __restrict__ k_dram_block_window_tmp,
643 const VDramBlockWindowTmp& __restrict__ v_dram_block_window_tmp,
644 const BiasDramBlockWindowTmp& __restrict__ bias_dram_block_window_tmp,
645 LSEaccDramBlockWindowTmp& __restrict__ lse_acc_dram_window_tmp,
647 PositionEncoding position_encoding,
649 void* __restrict__ smem_ptrk0,
650 void* __restrict__ smem_ptrk1,
651 void* __restrict__ smem_ptrv0,
652 void* __restrict__ smem_ptrv1)
const
655 std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
656 std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
657 std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
660 static_assert(
kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
662 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
663 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
I1] &&
664 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
665 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
I1] &&
666 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
I0] &&
667 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
I1],
669 ignore = bias_dram_block_window_tmp;
670 ignore = position_encoding;
673 constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
674 constexpr auto gemm_1 = Policy::template GetPVBlockGemm<Problem>();
676 using SaccBlockTileType =
decltype(gemm_0.MakeCBlockTile());
677 auto s_acc = SaccBlockTileType{};
680 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
681 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
683 using OaccBlockTileType =
decltype(gemm_1.MakeCBlockTile());
685 auto o_acc = OaccBlockTileType{};
694 auto m = MLBlockTileType{};
695 auto l = MLBlockTileType{};
701 const auto q_origin = q_dram_block_window_tmp.get_window_origin();
702 const auto [logical_seqlen_k_start, logical_seqlen_k_end] =
708 const index_t logical_num_total_loop =
710 if(logical_num_total_loop <= 0)
730 q_dram_block_window_tmp, Policy::template MakeQDramTileDistribution<Problem>());
734 Policy::template MakeQLdsBlockDescriptor<Problem>());
738 Policy::template MakeQLdsBlockDescriptor<Problem, true>());
740 auto q_lds_store_window =
742 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
745 auto q_lds_read_window =
747 Policy::template MakeQLdsBlockDescriptor<Problem>().get_lengths(),
749 Policy::template MakeQRegTileDistribution<Problem>());
753 auto q_tile =
load_tile(q_lds_read_window);
756 const index_t physical_seqlen_k_start = logical_seqlen_k_start;
757 const index_t physical_seqlen_k_end = logical_seqlen_k_end;
762 const index_t aligned_physical_seqlen_k_start = physical_seqlen_k_start;
766 {physical_seqlen_k_start, 0},
767 Policy::template MakeKDramTileDistribution<Problem, true>());
770 static_cast<KDataType* __restrict__
>(smem_ptrk0),
771 Policy::template MakeKLdsBlockDescriptor<Problem, true>());
774 static_cast<KDataType* __restrict__
>(smem_ptrk0),
775 Policy::template MakeKLdsBlockDescriptor<Problem, true, true>());
777 auto k_lds_write_window =
779 Policy::template MakeKLdsBlockDescriptor<Problem>().get_lengths(),
782 auto k_lds_read_window =
786 Policy::template MakeKRegTileDistribution<Problem>());
790 reinterpret_cast<SaccDataType*
>(
reinterpret_cast<char*
>(smem_ptrk0) +
791 Policy::template GetSmemSizeK<Problem>()),
792 Policy::template MakeSLdsBlockDescriptor<Problem>());
794 s_lds, Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
795 auto s_read_lds_window =
797 Policy::template MakeSLdsBlockDescriptor<Problem>().get_lengths(),
799 Policy::template MakeSRegTileDistribution<Problem>());
804 {physical_seqlen_k_start, 0},
805 Policy::template MakeVDramTileDistribution<Problem>());
808 reinterpret_cast<VDataType* __restrict__
>(
static_cast<char*
>(smem_ptrv0)),
809 Policy::template MakeVLdsBlockDescriptor<Problem>());
812 reinterpret_cast<VDataType* __restrict__
>(
static_cast<char*
>(smem_ptrv0)),
813 Policy::template MakeVLdsBlockDescriptor<Problem, true>());
815 auto v_lds_write_window =
817 Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(),
820 auto v_lds_read_window =
824 Policy::template MakeVRegTileDistribution<Problem>());
836 static_assert(1 <= k0_loops);
837 static_assert(1 <= k1_loops);
843 k_lds_write_window.set_bottom_tensor_view_data_ptr(
844 static_cast<KDataType* __restrict__
>(smem_ptrk1));
847 constexpr index_t k_vmem_insts = k_dram_window.get_num_of_access();
848 constexpr index_t v_vmem_insts = v_dram_window.get_num_of_access();
850 constexpr index_t k_lds_insts = k_lds_read_window.get_num_of_access();
851 constexpr index_t v_lds_insts = v_lds_read_window.get_num_of_access();
854 auto k_tile =
load_tile(k_lds_read_window);
856 __builtin_amdgcn_sched_barrier(0);
858 auto mainloop = [&](
KDataType* __restrict__ k_lds_write_ptr,
861 KDataType* __restrict__ v_lds_read_ptr) {
865 v_lds_write_window.set_bottom_tensor_view_data_ptr(v_lds_write_ptr);
871 if constexpr(1 < k0_loops)
873 static_for<0, k0_loops - 1, 1>{}([&](
auto i_k0) {
876 auto k_tile_switch =
load_tile(k_lds_read_window);
884 k_tile = k_tile_switch;
897 v_lds_read_window.set_bottom_tensor_view_data_ptr(v_lds_read_ptr);
902 if(i_total_loops == (num_total_loop - 1))
904 const auto k_origin =
905 make_tuple(
kN0 * i_total_loops + physical_seqlen_k_start, 0);
909 physical_seqlen_k_start_ = physical_seqlen_k_start,
910 physical_seqlen_k_end_ = physical_seqlen_k_end](
auto tile_idx) {
911 const auto col = k_origin.at(
I0) + tile_idx.at(
I1);
914 return physical_seqlen_k_end_ <= col;
922 const auto k_origin =
make_tuple(
kN0 * i_total_loops + physical_seqlen_k_start, 0);
924 bool need_perpixel_check =
926 if(need_perpixel_check)
930 const auto row = q_origin.at(
I0) + tile_idx.at(
I0);
931 const auto col = k_origin.at(
I0) + tile_idx.at(
I1);
932 return mask.IsOutOfBound(row, col);
963 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
964 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
969 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
970 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0);
973 const auto m_old = m;
975 [](
auto& e0,
auto e1,
auto e2) { e0 =
max(e1, e2); }, m, m_old, m_local);
978 s_new.get_tile_distribution());
996 constexpr auto p_spans =
decltype(p_compute)::get_distributed_spans();
999 auto row_max = scale_s * get_validated_m(m[i_idx]);
1001 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
1005 p_compute(i_j_idx) =
exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
1011 p_compute(i_j_idx) =
exp2(s_new[i_j_idx] - get_validated_m(m[i_idx]));
1015 p_compute(i_j_idx) =
exp2(scale_s * s_new[i_j_idx] - row_max);
1028 Policy::template MakePRegTileDistribution<Problem>());
1032 constexpr auto o_spans =
decltype(o_acc)::get_distributed_spans();
1035 const auto tmp = [&]() {
1039 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
1045 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
1049 auto row_max = scale_s * get_validated_m(m[i_idx]);
1050 return exp2(scale_s * m_old[i_idx] - row_max);
1054 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
1056 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
1058 o_acc(i_j_idx) *= tmp;
1064 k_lds_write_window.set_bottom_tensor_view_data_ptr(k_lds_write_ptr);
1067 if constexpr(1 < k1_loops)
1069 static_for<0, k1_loops - 1, 1>{}([&](
auto i_k1) {
1080 v_tile = v_tile_switch;
1093 k_lds_read_window.set_bottom_tensor_view_data_ptr(k_lds_read_ptr);
1098 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1099 __builtin_amdgcn_sched_group_barrier(0x100, 2, 0);
1104 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
1105 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
1111 bool is_even_loop = i_total_loops % 2 == 0;
1112 auto k_lds_write_ptr = is_even_loop ?
static_cast<KDataType* __restrict__
>(smem_ptrk0)
1113 :
static_cast<KDataType* __restrict__
>(smem_ptrk1);
1114 auto k_lds_read_ptr = is_even_loop ?
static_cast<KDataType* __restrict__
>(smem_ptrk1)
1115 :
static_cast<KDataType* __restrict__
>(smem_ptrk0);
1116 auto v_lds_write_ptr = is_even_loop ?
static_cast<VDataType* __restrict__
>(smem_ptrv1)
1117 :
static_cast<VDataType* __restrict__
>(smem_ptrv0);
1118 auto v_lds_read_ptr = is_even_loop ?
static_cast<VDataType* __restrict__
>(smem_ptrv0)
1119 :
static_cast<VDataType* __restrict__
>(smem_ptrv1);
1120 mainloop(k_lds_write_ptr, k_lds_read_ptr, v_lds_write_ptr, v_lds_read_ptr);
1122 }
while(i_total_loops < num_total_loop);
1129 constexpr auto lse_acc_spans =
decltype(lse_acc)::get_distributed_spans();
1135 lse_acc(i_idx) = m_[i_idx] /
C_LOG2E +
log(l_[i_idx]);
1141 lse_acc(i_idx) = m_[i_idx] /
C_LOG2E +
log(l_[i_idx]);
1145 lse_acc(i_idx) = m_[i_idx] * scale_s /
C_LOG2E +
log(l_[i_idx]);
1150 store_tile(lse_acc_dram_window_tmp, lse_acc);
1154 constexpr auto o_spans =
decltype(o_acc)::get_distributed_spans();
1158 const auto tmp = [&]() {
1160 FmhaMask::IsMasking)
1162 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
1165 return 1 / l[i_idx];
1168 constexpr auto i_j_idx =
make_tuple(idx0, idx1);
1169 o_acc(i_j_idx) *= tmp;
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition config.hpp:234
#define CK_TILE_HOST_DEVICE
Definition config.hpp:42
Definition tile/core/algorithm/cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition bfloat16.hpp:428
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition type_traits.hpp:21
CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:119
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition tile_elementwise.hpp:95
CK_TILE_HOST_DEVICE constexpr auto make_tensor_view(DataType *__restrict__ p, const tensor_descriptor< Ts... > &desc)
Definition tensor_view.hpp:452
CK_TILE_DEVICE constexpr auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition slice_tile.hpp:23
@ ALIBI
Definition block_attention_bias_enum.hpp:15
@ NO_BIAS
Definition block_attention_bias_enum.hpp:13
@ ELEMENTWISE_BIAS
Definition block_attention_bias_enum.hpp:14
constant< b > bool_constant
Definition tile/core/numeric/integral_constant.hpp:43
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={}, bool_constant< CrossWarp >={})
Definition block_reduce.hpp:21
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition tile_elementwise.hpp:23
CK_TILE_DEVICE void block_sync_lds_direct_load()
Definition arch.hpp:288
CK_TILE_DEVICE void block_sync_lds()
Definition arch.hpp:282
CK_TILE_HOST_DEVICE constexpr auto make_static_distributed_tensor(const StaticTileDistribution &)
Definition static_distributed_tensor.hpp:142
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile dist...
Definition load_tile_transpose.hpp:403
constant< v > number
Definition tile/core/numeric/integral_constant.hpp:37
CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition null_tile_window.hpp:75
CK_TILE_DEVICE auto cast_tile(const SrcTensor &src_tensor)
Definition tile_elementwise.hpp:327
constexpr detail::ignore_t ignore
Definition tile/core/utility/ignore.hpp:20
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition block_reduce.hpp:191
CK_TILE_HOST_DEVICE constexpr auto integer_divide_ceil(X x, Y y)
Definition tile/core/numeric/math.hpp:149
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition sweep_tile.hpp:20
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition null_tile_window.hpp:95
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition static_distributed_tensor.hpp:175
CK_TILE_HOST_DEVICE constexpr T max(T x)
Definition tile/core/numeric/math.hpp:161
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition store_tile.hpp:23
int32_t index_t
Definition integer.hpp:9
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition tile_elementwise.hpp:177
CK_TILE_HOST_DEVICE constexpr Y type_convert(X x)
Definition tile/core/numeric/type_convert.hpp:29
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition bfloat16.hpp:425
CK_TILE_HOST_DEVICE constexpr auto make_tuple(Xs &&... xs)
Definition tile/core/container/tuple.hpp:360
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:16
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:27
static constexpr bool kPadHeadDimV
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:64
static constexpr index_t kN0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:45
remove_cvref_t< typename Problem::VDataType > VDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:24
remove_cvref_t< typename Problem::QDataType > QDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:22
static constexpr bool kHasDropout
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:68
static constexpr bool kPadHeadDimQ
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:62
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:26
static constexpr bool kPadSeqLenK
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:61
remove_cvref_t< typename Problem::KDataType > KDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:23
static constexpr index_t kN1
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:47
static constexpr index_t kAlignmentQ
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:80
static constexpr index_t kNXdl
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:52
static constexpr index_t kAlignmentK
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:81
static constexpr index_t kBlockPerCu
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:94
static constexpr index_t kQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:49
static constexpr index_t kNWarp
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:51
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:25
static constexpr bool kStoreLSE
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:70
static constexpr bool kQLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:38
static constexpr index_t kAlignmentV
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:82
static constexpr bool kHasUnevenSplits
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:71
static constexpr index_t kSubQKHeaddim
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:50
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:28
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &__restrict__ q_dram_block_window_tmp, const KDramBlockWindowTmp &__restrict__ k_dram_block_window_tmp, const VDramBlockWindowTmp &__restrict__ v_dram_block_window_tmp, const BiasDramBlockWindowTmp &__restrict__ bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &__restrict__ lse_acc_dram_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void *__restrict__ smem_ptrk0, void *__restrict__ smem_ptrk1, void *__restrict__ smem_ptrv0, void *__restrict__ smem_ptrv1) const
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:641
static constexpr index_t kM0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:44
static constexpr index_t kAlignmentOacc
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:89
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:34
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:29
remove_cvref_t< Problem_ > Problem
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:20
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:37
remove_cvref_t< typename Problem::PDataType > PDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:30
static constexpr index_t kK1
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:48
static constexpr bool kKLoadOnce
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:40
static constexpr bool kHasLogitsSoftCap
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:67
static constexpr index_t kK0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:46
static constexpr auto I1
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:18
static CK_TILE_HOST_DEVICE constexpr ck_tile::index_t GetSmemSize()
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:127
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:31
static constexpr bool kPadSeqLenQ
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:60
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:33
remove_cvref_t< Policy_ > Policy
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:21
static constexpr const char * name
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:125
static constexpr index_t kBlockSize
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:42
remove_cvref_t< typename Problem::ODataType > ODataType
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:32
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, LSEaccDramBlockWindowTmp &lse_acc_dram_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, void *smem_ptr) const
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:140
static constexpr index_t kAlignmentBias
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:91
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:36
static constexpr auto BiasEnum
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:69
static constexpr auto I0
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:17
static constexpr bool kIsGroupMode
Definition block_fmha_pipeline_qr_ks_vs_async_trload.hpp:59
static CK_TILE_HOST_DEVICE constexpr T infinity()
Definition tile/core/numeric/numeric.hpp:38
Definition tile/core/container/sequence.hpp:49
Definition tile/core/utility/functional.hpp:43
#define C_LOG2E
Definition tile/core/numeric/math.hpp:469