debug.hpp Source File

debug.hpp Source File

Composable Kernel: debug.hpp Source File
utility/debug.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#ifndef UTILITY_DEBUG_HPP
5#define UTILITY_DEBUG_HPP
6#include "type.hpp"
7
8namespace ck {
9namespace debug {
10
11namespace detail {
12template <typename T, typename Enable = void>
14
15template <typename T>
17{
18 using type = float;
19 __host__ __device__ static void Print(const T& p) { printf("%.3f ", static_cast<type>(p)); }
20};
21
22template <>
23struct PrintAsType<ck::half_t, void>
24{
25 using type = float;
26 __host__ __device__ static void Print(const ck::half_t& p)
27 {
28 printf("%.3f ", static_cast<type>(p));
29 }
30};
31
32template <typename T>
34{
35 using type = int;
36 __host__ __device__ static void Print(const T& p) { printf("%d ", static_cast<type>(p)); }
37};
38} // namespace detail
39
40// Print at runtime the data in shared memory in 128 bytes per row format given shared mem pointer
41// and the number of elements. Can optionally specify strides between elements and how many bytes'
42// worth of data per row.
43//
44// Usage example:
45//
46// debug::print_shared(a_block_buf.p_data_, index_t(a_block_desc_k0_m_k1.GetElementSpaceSize()));
47//
48template <typename T, index_t element_stride = 1, index_t row_bytes = 128>
49__device__ void print_shared(T const* p_shared, index_t num_elements)
50{
51 constexpr index_t row_elements = row_bytes / sizeof(T);
52 static_assert((element_stride >= 1 && element_stride <= row_elements),
53 "element_stride should between [1, row_elements]");
54
55 index_t wgid = blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z;
56 index_t tid =
57 (threadIdx.z * (blockDim.x * blockDim.y)) + (threadIdx.y * blockDim.x) + threadIdx.x;
58
59 __syncthreads();
60
61 if(tid == 0)
62 {
63 printf("\nWorkgroup id %d, bytes per row %d, element stride %d\n\n",
64 wgid,
65 row_bytes,
66 element_stride);
67 for(index_t i = 0; i < num_elements; i += row_elements)
68 {
69 printf("elem %5d: ", i);
70 for(index_t j = 0; j < row_elements; j += element_stride)
71 {
72 detail::PrintAsType<T>::Print(p_shared[i + j]);
73 }
74
75 printf("\n");
76 }
77 printf("\n");
78 }
79
80 __syncthreads();
81}
82
83template <index_t... Ids>
84__device__ static bool is_thread_local_1d_id_idx()
85{
86 const auto tid = get_thread_local_1d_id();
87 return ((tid == Ids) || ...);
88}
89
// Compile-time inspection helpers. Both overloads are empty and marked [[deprecated]], so
// instantiating one makes the compiler emit a deprecation diagnostic for that
// specialization, which can be used to inspect its template arguments:
//   - `CK_PRINT<T1, T2, ...>()` for types T1, T2, ...
//   - `CK_PRINT<v1, v2, ...>()` for constant values v1, v2, ... (the pack is `auto...`,
//     so the values need not share a type)
// In a non-evaluated context, write `using _dummy = decltype(CK_PRINT<...>());`.
// The diagnostic is a warning; set BUILD_DEV to OFF to avoid enabling Werror.
template <auto... Values>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() {}

template <typename... Types>
[[deprecated("Help function to print value")]] inline constexpr void CK_PRINT() {}
102
103} // namespace debug
104} // namespace ck
105
106#endif // UTILITY_DEBUG_HPP
Definition utility/debug.hpp:11
Definition utility/debug.hpp:9
constexpr void CK_PRINT()
Definition utility/debug.hpp:95
__device__ void print_shared(T const *p_shared, index_t num_elements)
Definition utility/debug.hpp:49
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
_Float16 half_t
Definition data_type.hpp:31
std::enable_if< B, T > enable_if
Definition enable_if.hpp:24
__device__ index_t get_thread_local_1d_id()
Definition get_id.hpp:41
const GenericPointer< typename T::ValueType > T2 value
Definition pointer.h:1697
__host__ static __device__ void Print(const T &p)
Definition utility/debug.hpp:19
__host__ static __device__ void Print(const T &p)
Definition utility/debug.hpp:36
__host__ static __device__ void Print(const ck::half_t &p)
Definition utility/debug.hpp:26
float type
Definition utility/debug.hpp:25
Definition utility/debug.hpp:13
Definition type.hpp:187
Definition type.hpp:206