flush_cache.hpp Source File

flush_cache.hpp Source File#

Composable Kernel: flush_cache.hpp Source File
flush_cache.hpp
Go to the documentation of this file.
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3
4#pragma once
5
6#include <hip/hip_runtime.h>
7#include <numeric>
8#include <set>
9#include <vector>
10
11#include "ck/ck.hpp"
12#include "ck/utility/env.hpp"
13#include "ck/stream_config.hpp"
16namespace ck {
17namespace utility {
18
19template <typename Argument, typename AsDataType, typename BsDataType, typename DsDataType>
21{
22 static constexpr index_t NumAs = AsDataType::Size();
23 static constexpr index_t NumBs = BsDataType::Size();
24 static constexpr index_t NumDs = DsDataType::Size();
25
26 using AsGridPointer = decltype(Argument::p_as_grid);
27 using BsGridPointer = decltype(Argument::p_bs_grid);
28 using DsGridPointer = decltype(Argument::p_ds_grid);
29
32 std::size_t rotating_count_hint,
33 std::array<std::size_t, NumAs> size_as_,
34 std::array<std::size_t, NumBs> size_bs_,
35 std::array<std::size_t, NumDs> size_ds_)
36 : arg(arg_),
37 rotating_count(rotating_count_hint),
38 size_as(size_as_),
39 size_bs(size_bs_),
40 size_ds(size_ds_)
41 {
42 p_as_grids.push_back(arg.p_as_grid);
43 p_bs_grids.push_back(arg.p_bs_grid);
44 p_ds_grids.push_back(arg.p_ds_grid);
45
46 // limit the rotating count to prevent oom
47 const uint64_t footprint = std::accumulate(size_as.begin(), size_as.end(), 0UL) +
48 std::accumulate(size_bs.begin(), size_bs.end(), 0UL) +
49 std::accumulate(size_ds.begin(), size_ds.end(), 0UL);
50 const uint64_t max_rotating_count = (1ULL << 31) / footprint;
51 rotating_count = std::min(rotating_count, max_rotating_count);
52
53 for(size_t i = 1; i < rotating_count; i++)
54 {
55 {
56 AsGridPointer as_buffer;
57 static_for<0, NumAs, 1>{}([&](auto j) {
58 void* pADeviceBuf;
59 hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_as_[j]));
60 hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
61 static_cast<const void*>(p_as_grids[0][j]),
62 size_as_[j],
63 hipMemcpyDeviceToDevice));
64 using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
65
66 as_buffer(j) = static_cast<const ADataType*>(pADeviceBuf);
67 });
68 p_as_grids.push_back(as_buffer);
69 }
70
71 {
72 BsGridPointer bs_buffer;
73 static_for<0, NumBs, 1>{}([&](auto j) {
74 void* pBDeviceBuf;
75 hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_bs_[j]));
76 hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
77 static_cast<const void*>(p_bs_grids[0][j]),
78 size_bs_[j],
79 hipMemcpyDeviceToDevice));
80 using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
81
82 bs_buffer(j) = static_cast<const BDataType*>(pBDeviceBuf);
83 });
84 p_bs_grids.push_back(bs_buffer);
85 }
86
87 {
88 DsGridPointer ds_buffer;
89 static_for<0, NumDs, 1>{}([&](auto j) {
90 void* pDDeviceBuf;
91 hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
92 hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
93 static_cast<const void*>(p_ds_grids[0][j]),
94 size_ds_[j],
95 hipMemcpyDeviceToDevice));
96
97 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
98
99 ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
100 });
101
102 p_ds_grids.push_back(ds_buffer);
103 }
104 }
105 }
106
107 void Next()
108 {
109 if(rotating_count > 1)
110 {
111 std::size_t idx = iter++ % rotating_count;
112 arg.p_as_grid = p_as_grids[idx];
113 arg.p_bs_grid = p_bs_grids[idx];
114 arg.p_ds_grid = p_ds_grids[idx];
115 }
116 }
117 void Print()
118 {
119 std::cout << "RotatingMemWrapperMultiD: { size_a: {";
121 [&](auto j) { std::cout << size_as[j] << (j.value < NumAs - 1 ? ", " : ""); });
122 std::cout << "}, size_b: {";
124 [&](auto j) { std::cout << size_bs[j] << (j.value < NumBs - 1 ? ", " : ""); });
125 std::cout << "}, rotating_count: " << rotating_count << "}" << std::endl;
126 }
128 {
129 if(rotating_count > 1)
130 {
131 // restore ptr
132 arg.p_as_grid = p_as_grids[0];
133 arg.p_bs_grid = p_bs_grids[0];
134 arg.p_ds_grid = p_ds_grids[0];
135
136 // free device mem
137 for(size_t i = 1; i < rotating_count; i++)
138 {
139 static_for<0, NumAs, 1>{}([&](auto j) {
140 using ADataType = remove_cvref_t<tuple_element_t<j.value, AsDataType>>;
142 hipFree(static_cast<void*>(const_cast<ADataType*>(p_as_grids[i][j]))));
143 });
144
145 static_for<0, NumBs, 1>{}([&](auto j) {
146 using BDataType = remove_cvref_t<tuple_element_t<j.value, BsDataType>>;
148 hipFree(static_cast<void*>(const_cast<BDataType*>(p_bs_grids[i][j]))));
149 });
150
151 static_for<0, NumDs, 1>{}([&](auto j) {
152 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
154 hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
155 });
156 }
157 }
158 }
159
160 private:
161 Argument& arg;
162 std::size_t iter = 0;
163 std::size_t rotating_count = 1;
164 std::array<std::size_t, NumAs> size_as = {0};
165 std::array<std::size_t, NumBs> size_bs = {0};
166 std::array<std::size_t, NumDs> size_ds = {0};
167 std::vector<AsGridPointer> p_as_grids;
168 std::vector<BsGridPointer> p_bs_grids;
169 std::vector<DsGridPointer> p_ds_grids;
170};
171
172template <typename Argument, typename DsDataType>
174{
175 static constexpr index_t NumDs = DsDataType::Size();
176
177 using ADataType = decltype(Argument::p_a_grid);
178 using BDataType = decltype(Argument::p_b_grid);
179 using DsGridPointer = decltype(Argument::p_ds_grid);
180
183 std::size_t rotating_count_hint,
184 std::size_t size_a_,
185 std::size_t size_b_,
186 std::array<std::size_t, NumDs> size_ds_)
187 : arg(arg_),
188 rotating_count(rotating_count_hint),
189 size_a(size_a_),
190 size_b(size_b_),
191 size_ds(size_ds_)
192 {
193 p_a_grids.push_back(arg.p_a_grid);
194 p_b_grids.push_back(arg.p_b_grid);
195 p_ds_grids.push_back(arg.p_ds_grid);
196
197 // limit the rotating count to prevent oom
198 const uint64_t footprint =
199 std::accumulate(size_ds.begin(), size_ds.end(), 0UL) + (size_a + size_b);
200 const uint64_t max_rotating_count = (1ULL << 31) / footprint;
201 rotating_count = std::min(rotating_count, max_rotating_count);
202
203 for(size_t i = 1; i < rotating_count; i++)
204 {
205 {
206 void* pADeviceBuf;
207 hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
208 hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
209 const_cast<void*>(p_a_grids[0]),
210 size_a_,
211 hipMemcpyDeviceToDevice));
212 p_a_grids.push_back(pADeviceBuf);
213 }
214
215 {
216 void* pBDeviceBuf;
217 hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
218 hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
219 const_cast<void*>(p_b_grids[0]),
220 size_b_,
221 hipMemcpyDeviceToDevice));
222 p_b_grids.push_back(pBDeviceBuf);
223 }
224
225 {
226
227 DsGridPointer ds_buffer;
228 static_for<0, NumDs, 1>{}([&](auto j) {
229 void* pDDeviceBuf;
230 hip_check_error(hipMalloc(static_cast<void**>(&pDDeviceBuf), size_ds_[j]));
231 hip_check_error(hipMemcpy(static_cast<void*>(pDDeviceBuf),
232 static_cast<const void*>(p_ds_grids[0][j]),
233 size_ds_[j],
234 hipMemcpyDeviceToDevice));
235
236 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
237
238 ds_buffer(j) = static_cast<const DDataType*>(pDDeviceBuf);
239 });
240
241 p_ds_grids.push_back(ds_buffer);
242 }
243 }
244 }
245
246 void Next()
247 {
248 if(rotating_count > 1)
249 {
250 std::size_t idx = iter++ % rotating_count;
251 arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
252 arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
253 arg.p_ds_grid = p_ds_grids[idx];
254 }
255 }
256 void Print()
257 {
258 std::cout << "RotatingMemWrapperMultiD: { size_a: " << size_a << ", size_b: " << size_b
259 << ", rotating_count: " << rotating_count << "}" << std::endl;
260 }
262 {
263 if(rotating_count > 1)
264 {
265 // restore ptr
266 arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
267 arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
268 arg.p_ds_grid = p_ds_grids[0];
269
270 // free device mem
271 for(size_t i = 1; i < rotating_count; i++)
272 {
273 hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
274 hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
275
276 static_for<0, NumDs, 1>{}([&](auto j) {
277 using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;
279 hipFree(static_cast<void*>(const_cast<DDataType*>(p_ds_grids[i][j]))));
280 });
281 }
282 }
283 }
284
285 private:
286 Argument& arg;
287 std::size_t iter = 0;
288 std::size_t rotating_count = 1;
289 std::size_t size_a = 0;
290 std::size_t size_b = 0;
291 std::array<std::size_t, NumDs> size_ds = {0};
292 std::vector<const void*> p_a_grids;
293 std::vector<const void*> p_b_grids;
294 std::vector<DsGridPointer> p_ds_grids;
295};
296
297template <typename Argument>
299{
300 using ADataType = decltype(Argument::p_a_grid);
301 using BDataType = decltype(Argument::p_b_grid);
302
304 RotatingMemWrapper(Argument& arg_,
305 std::size_t rotating_count_hint,
306 std::size_t size_a_,
307 std::size_t size_b_)
308 : arg(arg_), rotating_count(rotating_count_hint), size_a(size_a_), size_b(size_b_)
309 {
310 p_a_grids.push_back(arg.p_a_grid);
311 p_b_grids.push_back(arg.p_b_grid);
312
313 // limit the rotating count to prevent oom
314 const uint64_t footprint = (size_a + size_b);
315 const uint64_t max_rotating_count = (1ULL << 31) / footprint;
316 rotating_count = std::min(rotating_count, max_rotating_count);
317
318 for(size_t i = 1; i < rotating_count; i++)
319 {
320 {
321 void* pADeviceBuf;
322 hip_check_error(hipMalloc(static_cast<void**>(&pADeviceBuf), size_a_));
323 hip_check_error(hipMemcpy(static_cast<void*>(pADeviceBuf),
324 const_cast<void*>(p_a_grids[0]),
325 size_a_,
326 hipMemcpyDeviceToDevice));
327 p_a_grids.push_back(pADeviceBuf);
328 }
329
330 {
331 void* pBDeviceBuf;
332 hip_check_error(hipMalloc(static_cast<void**>(&pBDeviceBuf), size_b_));
333 hip_check_error(hipMemcpy(static_cast<void*>(pBDeviceBuf),
334 const_cast<void*>(p_b_grids[0]),
335 size_b_,
336 hipMemcpyDeviceToDevice));
337 p_b_grids.push_back(pBDeviceBuf);
338 }
339 }
340 }
341
342 void Next()
343 {
344 if(rotating_count > 1)
345 {
346 std::size_t idx = iter++ % rotating_count;
347 arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[idx]);
348 arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[idx]);
349 }
350 }
351 void Print()
352 {
353 std::cout << "RotatingMemWrapper: { size_a: " << size_a << ", size_b: " << size_b
354 << ", rotating_count: " << rotating_count << "}" << std::endl;
355 }
357 {
358 if(rotating_count > 1)
359 {
360 // restore ptr
361 arg.p_a_grid = reinterpret_cast<ADataType>(p_a_grids[0]);
362 arg.p_b_grid = reinterpret_cast<BDataType>(p_b_grids[0]);
363
364 // free device mem
365 for(size_t i = 1; i < rotating_count; i++)
366 {
367 hip_check_error(hipFree(const_cast<void*>(p_a_grids[i])));
368 hip_check_error(hipFree(const_cast<void*>(p_b_grids[i])));
369 }
370 }
371 }
372
373 private:
374 Argument& arg;
375 std::size_t iter = 0;
376 std::size_t rotating_count = 1;
377 std::size_t size_a = 0;
378 std::size_t size_b = 0;
379 std::vector<const void*> p_a_grids;
380 std::vector<const void*> p_b_grids;
381};
382
383inline void flush_icache()
384{
385 hipDeviceProp_t deviceProps;
386 hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
387 int32_t gpu_block3 = deviceProps.multiProcessorCount * 60;
388
389 ck::flush_icache<<<dim3(gpu_block3), dim3(64), 0, nullptr>>>();
390 hip_check_error(hipGetLastError());
391}
392// if TimePrePress == false, return time does not include preprocess's time
393template <bool TimePreprocess,
394 typename GemmArgs,
395 typename... Args,
396 typename F,
397 typename PreProcessFunc>
399 PreProcessFunc preprocess,
400 F kernel,
401 dim3 grid_dim,
402 dim3 block_dim,
403 std::size_t lds_byte,
404 GemmArgs& gemm_args,
405 Args... args)
406{
407#if CK_TIME_KERNEL
408#define MEDIAN 0
409 if(stream_config.time_kernel_)
410 {
411 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
412 {
413 printf("%s: grid_dim {%u, %u, %u}, block_dim {%u, %u, %u} \n",
414 __func__,
415 grid_dim.x,
416 grid_dim.y,
417 grid_dim.z,
418 block_dim.x,
419 block_dim.y,
420 block_dim.z);
421
422 printf("Warm up %d times\n", stream_config.cold_niters_);
423 }
424 // warm up
425 for(int i = 0; i < stream_config.cold_niters_; ++i)
426 {
427 kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
428 hip_check_error(hipGetLastError());
429 }
430
431 const int nrepeat = stream_config.nrepeat_;
432 if(nrepeat == 0)
433 {
434 return 0.0;
435 }
436 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
437 {
438 printf("Start running %d times...\n", nrepeat);
439 }
440
441#if MEDIAN
442 std::set<float> times;
443#else
444 float total_time = 0;
445#endif
446 hipEvent_t start, stop;
447
448 hip_check_error(hipEventCreate(&start));
449 hip_check_error(hipEventCreate(&stop));
450
451 hip_check_error(hipDeviceSynchronize());
452 hip_check_error(hipEventRecord(start, stream_config.stream_id_));
453
454 for(int i = 0; i < nrepeat; ++i)
455 {
456 if constexpr(!TimePreprocess)
457 {
458 preprocess();
459 }
460
461 // hipEvent_t start, stop;
462
463 // hip_check_error(hipEventCreate(&start));
464 // hip_check_error(hipEventCreate(&stop));
465
466 // hip_check_error(hipDeviceSynchronize());
467 // hip_check_error(hipEventRecord(start, stream_config.stream_id_));
468 // calculate preprocess time
469 if constexpr(TimePreprocess)
470 {
471 preprocess();
472 }
473 // run real kernel
474 kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
475 hip_check_error(hipGetLastError());
476 // end real kernel
477
478 // hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
479 // hip_check_error(hipEventSynchronize(stop));
480 // float cur_time = 0;
481 // hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
482 // #if MEDIAN
483 // times.insert(cur_time);
484 // #else
485 // total_time += cur_time;
486 // #endif
487
488#if !defined(CK_USE_WMMA)
489 if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
490 {
491 // std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
492
493 printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
494 static_cast<const void*>(gemm_args.p_a_grid),
495 static_cast<const void*>(gemm_args.p_b_grid));
496 }
497#endif
498 }
499 hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
500 hip_check_error(hipEventSynchronize(stop));
501 float cur_time = 0;
502 hip_check_error(hipEventElapsedTime(&cur_time, start, stop));
503#if MEDIAN
504 times.insert(cur_time);
505#else
506 total_time += cur_time;
507#endif
508
509#if MEDIAN
510 auto mid = times.begin();
511 std::advance(mid, (nrepeat - 1) / 2);
512 if(nrepeat % 2 == 1)
513 {
514 return *mid;
515 }
516 else
517 {
518 auto mid_next = mid;
519 std::advance(mid_next, 1);
520 return (*mid + *mid_next) / 2;
521 }
522#else
523 // return total_time / nrepeat;
524 hipDeviceProp_t deviceProps;
525 hip_check_error(hipGetDeviceProperties(&deviceProps, 0));
526 float preprocess_offset = deviceProps.multiProcessorCount == 80 ? 0.005 : 0.01;
527 return (total_time - preprocess_offset * nrepeat) / nrepeat;
528#endif
529 }
530 else
531 {
532 preprocess();
533 kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
534 hip_check_error(hipGetLastError());
535
536 return 0;
537 }
538#else
539 kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
540 hip_check_error(hipGetLastError());
541
542 return 0;
543#endif
544}
545
546} // namespace utility
547} // namespace ck
void hip_check_error(hipError_t x)
Definition host_utility/hip_check_error.hpp:10
float launch_and_time_kernel_with_preprocess(const StreamConfig &stream_config, PreProcessFunc preprocess, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition host_utility/kernel_launch.hpp:91
Definition flush_cache.hpp:17
void flush_icache()
Definition flush_cache.hpp:383
Definition ck.hpp:268
int32_t index_t
Definition ck.hpp:299
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition type.hpp:297
typename tuple_element< I, TTuple >::type tuple_element_t
Definition utility/tuple.hpp:208
bool EnvIsEnabled(EnvVar)
Definition utility/env.hpp:140
signed int int32_t
Definition stdint.h:123
unsigned __int64 uint64_t
Definition stdint.h:136
Definition ck/stream_config.hpp:10
int cold_niters_
Definition ck/stream_config.hpp:14
bool time_kernel_
Definition ck/stream_config.hpp:12
int nrepeat_
Definition ck/stream_config.hpp:15
hipStream_t stream_id_
Definition ck/stream_config.hpp:11
Definition functional2.hpp:33
decltype(Argument::p_a_grid) ADataType
Definition flush_cache.hpp:300
RotatingMemWrapper(Argument &arg_, std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_)
Definition flush_cache.hpp:304
decltype(Argument::p_b_grid) BDataType
Definition flush_cache.hpp:301
~RotatingMemWrapper()
Definition flush_cache.hpp:356
void Print()
Definition flush_cache.hpp:351
void Next()
Definition flush_cache.hpp:342
static constexpr index_t NumBs
Definition flush_cache.hpp:23
RotatingMemWrapperMultiABD(Argument &arg_, std::size_t rotating_count_hint, std::array< std::size_t, NumAs > size_as_, std::array< std::size_t, NumBs > size_bs_, std::array< std::size_t, NumDs > size_ds_)
Definition flush_cache.hpp:31
static constexpr index_t NumDs
Definition flush_cache.hpp:24
decltype(Argument::p_bs_grid) BsGridPointer
Definition flush_cache.hpp:27
decltype(Argument::p_ds_grid) DsGridPointer
Definition flush_cache.hpp:28
void Print()
Definition flush_cache.hpp:117
void Next()
Definition flush_cache.hpp:107
static constexpr index_t NumAs
Definition flush_cache.hpp:22
decltype(Argument::p_as_grid) AsGridPointer
Definition flush_cache.hpp:26
~RotatingMemWrapperMultiABD()
Definition flush_cache.hpp:127
decltype(Argument::p_b_grid) BDataType
Definition flush_cache.hpp:178
void Print()
Definition flush_cache.hpp:256
static constexpr index_t NumDs
Definition flush_cache.hpp:175
decltype(Argument::p_a_grid) ADataType
Definition flush_cache.hpp:177
~RotatingMemWrapperMultiD()
Definition flush_cache.hpp:261
decltype(Argument::p_ds_grid) DsGridPointer
Definition flush_cache.hpp:179
RotatingMemWrapperMultiD(Argument &arg_, std::size_t rotating_count_hint, std::size_t size_a_, std::size_t size_b_, std::array< std::size_t, NumDs > size_ds_)
Definition flush_cache.hpp:182
void Next()
Definition flush_cache.hpp:246
#define CK_ENV(name)
Definition utility/env.hpp:129