7#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
8#define CK_MX_FP8_CVT_FAST_PATH 1
10#define CK_MX_FP8_CVT_FAST_PATH 0
16#if CK_MX_FP8_CVT_FAST_PATH
17template <ck_fp8_
interpretation_t
interpret>
18static __device__
float cast_to_f32_from_f8_scaled(
float scale,
fp8_storage_t v)
23 unsigned char i8val[4];
29 "Only OCP interpretations are supported");
33 return __builtin_amdgcn_cvt_scalef32_f32_fp8(val.i32val, scale, 0);
37 return __builtin_amdgcn_cvt_scalef32_f32_bf8(val.i32val, scale, 0);
41template <ck_fp8_
interpretation_t
interpret>
48 "Only OCP interpretations are supported");
52 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp8(i16val, scale, 0);
56 return __builtin_amdgcn_cvt_scalef32_pk_f32_bf8(i16val, scale, 0);
60template <ck_fp8_
interpretation_t
interpret,
bool stochastic_rounding = false>
61static __device__
fp8_storage_t cast_to_f8_from_f32_scaled(
float v,
75 vector_type<int16_t, 2>::type v2i16;
82 if constexpr(stochastic_rounding)
86 ? __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, val.fval, rng, scale, 0)
87 : __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, val.fval, rng, scale, 0);
99 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
108 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
115 i8data = ret.v4i8[0];
120template <ck_fp8_
interpretation_t
interpret,
bool stochastic_rounding = false>
122 unsigned int rng = 0,
129 vector_type<int16_t, 2>::type v2i16;
130 StaticallyIndexedArray<fp8x2_storage_t, 2> v2f8x2;
133 if constexpr(stochastic_rounding)
138 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[0], rng, scale, 0);
140 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_fp8_f32(ret.ival, v[1], rng, scale, 0);
145 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[0], rng, scale, 0);
147 ret.ival = __builtin_amdgcn_cvt_scalef32_sr_bf8_f32(ret.ival, v[1], rng, scale, 0);
160 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_fp8_f32( ret.v2i16,
169 ret.v2i16 = __builtin_amdgcn_cvt_scalef32_pk_bf8_f32( ret.v2i16,
182#if CK_MX_FP8_CVT_FAST_PATH
193template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
194__host__ __device__
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f,
float scale)
196 __is_interpret_supported(interp);
198 if constexpr(stochastic_rounding)
201 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
204 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
217template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
221 __is_interpret_supported(interp);
223 if constexpr(stochastic_rounding)
226 rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
229 return cast_to_f8_from_f32_scaled<interp, stochastic_rounding>(f, rng, scale);
244template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
245__host__ __device__
static inline fp8_storage_t cvt_float_to_fp8_scaled(
const float f,
float scale)
250 "Only OCP interpretations are supported");
253 if constexpr(stochastic_rounding)
255 constexpr int seed = 1254739;
261 return cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f / scale, rng);
265 return cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f / scale, rng);
269 __hip_assert(
false &&
"FP8 type is not supported by current target device");
284template <ck_fp8_
interpretation_t
interp,
bool stochastic_rounding = false>
291 "Only OCP interpretations are supported");
294 if constexpr(stochastic_rounding)
296 constexpr int seed = 1254739;
302 return {cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[0] / scale, rng),
303 cast_to_f8<float, 3, 4, false, true, stochastic_rounding>(f[1] / scale, rng)};
307 return {cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[0] / scale, rng),
308 cast_to_f8<float, 2, 5, false, true, stochastic_rounding>(f[1] / scale, rng)};
312 __hip_assert(
false &&
"FP8 type is not supported by current target device");
322template <
typename Y,
typename X>
326template <
typename Y,
typename X>
333 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
340 return bf8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
348 return f8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret>(x, scale)};
356 return bf8x2_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret>(x, scale)};
455 return f8_ocp_t{fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
463 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
471 fp8_impl::cvt_float_to_fp8_scaled<f8_ocp_t::default_interpret, true>(x, scale)};
480 fp8_impl::cvt_float_to_fp8_scaled<bf8_ocp_t::default_interpret, true>(x, scale)};
float float2_t
Definition amd_ck_fp8.hpp:92
fp8_storage_t fp8x2_storage_t
Definition amd_ck_fp8.hpp:88
typename vector_type< float, 16 >::type float16_t
Definition dtype_vector.hpp:2148
__host__ __device__ f8x16_ocp_t mxf8_convert_sr< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:485
__host__ __device__ f8x2_ocp_t mxf8_convert_rne< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:345
__host__ __device__ constexpr Y mxf8_convert_rne(X x, float scale)
__host__ __device__ f8_ocp_t mxf8_convert_rne< f8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:331
__host__ __device__ bf8_ocp_t mxf8_convert_sr< bf8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:460
@ CK_E4M3_OCP
Definition amd_ck_fp8.hpp:71
@ CK_E5M2_OCP
Definition amd_ck_fp8.hpp:72
typename vector_type< f8_ocp_t, 32 >::type f8x32_ocp_t
Definition dtype_vector.hpp:2204
__host__ __device__ f8_ocp_t mxf8_convert_sr< f8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:453
integral_constant< index_t, N > Number
Definition number.hpp:12
typename vector_type< bf8_ocp_t, 32 >::type bf8x32_ocp_t
Definition dtype_vector.hpp:2212
__host__ __device__ bf8x32_ocp_t mxf8_convert_rne< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:430
typename vector_type< bf8_ocp_t, 2 >::type bf8x2_ocp_t
Definition dtype_vector.hpp:2208
__host__ __device__ bf8_ocp_t mxf8_convert_rne< bf8_ocp_t, float >(float x, float scale)
Definition mxf8_utils.hpp:338
typename vector_type< float, 2 >::type float2_t
Definition dtype_vector.hpp:2145
__host__ __device__ bf8x16_ocp_t mxf8_convert_sr< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:508
__device__ index_t get_thread_global_1d_id()
Definition get_id.hpp:43
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed=seed_t)
Definition random_gen.hpp:19
typename vector_type< f8_ocp_t, 2 >::type f8x2_ocp_t
Definition dtype_vector.hpp:2200
typename vector_type< float, 32 >::type float32_t
Definition dtype_vector.hpp:2149
typename vector_type< f8_ocp_t, 16 >::type f8x16_ocp_t
Definition dtype_vector.hpp:2203
__host__ __device__ bf8x16_ocp_t mxf8_convert_rne< bf8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:384
__host__ __device__ f8x16_ocp_t mxf8_convert_rne< f8x16_ocp_t, float16_t >(float16_t x, float scale)
Definition mxf8_utils.hpp:361
__host__ __device__ bf8x32_ocp_t mxf8_convert_sr< bf8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:554
__host__ __device__ f8x32_ocp_t mxf8_convert_sr< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:531
__host__ __device__ constexpr Y mxf8_convert_sr(X x, float scale)
typename vector_type< bf8_ocp_t, 16 >::type bf8x16_ocp_t
Definition dtype_vector.hpp:2211
__host__ __device__ bf8x2_ocp_t mxf8_convert_sr< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:476
__host__ __device__ constexpr Y bit_cast(const X &x)
Definition type.hpp:306
__host__ __device__ f8x32_ocp_t mxf8_convert_rne< f8x32_ocp_t, float32_t >(float32_t x, float scale)
Definition mxf8_utils.hpp:407
__host__ __device__ bf8x2_ocp_t mxf8_convert_rne< bf8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:353
__host__ __device__ f8x2_ocp_t mxf8_convert_sr< f8x2_ocp_t, float2_t >(float2_t x, float scale)
Definition mxf8_utils.hpp:468
unsigned char fp8_storage_t
Definition amd_ck_fp8.hpp:64
_W64 unsigned int uintptr_t
Definition stdint.h:164
unsigned int uint32_t
Definition stdint.h:126
Definition amd_ck_fp8.hpp:369
Definition amd_ck_fp8.hpp:323
Definition functional2.hpp:33