1a825fb2cSChristian Sigg //===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===//
2a825fb2cSChristian Sigg //
3a825fb2cSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a825fb2cSChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5a825fb2cSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a825fb2cSChristian Sigg //
7a825fb2cSChristian Sigg //===----------------------------------------------------------------------===//
8a825fb2cSChristian Sigg //
9a825fb2cSChristian Sigg // Implements C wrappers around the ROCM library for easy linking in ORC jit.
10a825fb2cSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to
11a825fb2cSChristian Sigg // run on GPUs.
12a825fb2cSChristian Sigg //
13a825fb2cSChristian Sigg //===----------------------------------------------------------------------===//
14a825fb2cSChristian Sigg 
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <mutex>
#include <numeric>
#include <unordered_map>

#include "mlir/ExecutionEngine/CRunnerUtils.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

#include "hip/hip_runtime.h"
22a825fb2cSChristian Sigg 
// Evaluates `expr` (a call returning hipError_t) exactly once. On any status
// other than success, prints the source text of `expr` and the HIP error name
// (or "<unknown>" when HIP has no name for it) to stderr. Errors are reported
// only; nothing is returned or thrown. Note: no `//` comments may appear inside
// the macro body because of the line continuations.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t status) {                                                      \
    if (status == hipSuccess)                                                  \
      return;                                                                  \
    const char *statusName = hipGetErrorName(status);                          \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr,                          \
            statusName ? statusName : "<unknown>");                            \
  }(expr)
32a825fb2cSChristian Sigg 
// Per-thread device ordinal, recorded by mgpuSetDefaultDevice; starts at 0.
static thread_local int32_t defaultDevice{0};
34*84718d37SKrzysztof Drewniak 
35a825fb2cSChristian Sigg // Sets the `Context` for the duration of the instance and restores the previous
36a825fb2cSChristian Sigg // context on destruction.
37a825fb2cSChristian Sigg class ScopedContext {
38a825fb2cSChristian Sigg public:
ScopedContext()39a825fb2cSChristian Sigg   ScopedContext() {
40*84718d37SKrzysztof Drewniak     // Static reference to HIP primary context for device ordinal defaultDevice.
41a825fb2cSChristian Sigg     static hipCtx_t context = [] {
42a825fb2cSChristian Sigg       HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
43a825fb2cSChristian Sigg       hipDevice_t device;
44*84718d37SKrzysztof Drewniak       HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/defaultDevice));
45a825fb2cSChristian Sigg       hipCtx_t ctx;
46a825fb2cSChristian Sigg       HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&ctx, device));
47a825fb2cSChristian Sigg       return ctx;
48a825fb2cSChristian Sigg     }();
49a825fb2cSChristian Sigg 
50a825fb2cSChristian Sigg     HIP_REPORT_IF_ERROR(hipCtxPushCurrent(context));
51a825fb2cSChristian Sigg   }
52a825fb2cSChristian Sigg 
~ScopedContext()53a825fb2cSChristian Sigg   ~ScopedContext() { HIP_REPORT_IF_ERROR(hipCtxPopCurrent(nullptr)); }
54a825fb2cSChristian Sigg };
55a825fb2cSChristian Sigg 
mgpuModuleLoad(void * data)56a825fb2cSChristian Sigg extern "C" hipModule_t mgpuModuleLoad(void *data) {
57a825fb2cSChristian Sigg   ScopedContext scopedContext;
58a825fb2cSChristian Sigg   hipModule_t module = nullptr;
59a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
60a825fb2cSChristian Sigg   return module;
61a825fb2cSChristian Sigg }
62a825fb2cSChristian Sigg 
mgpuModuleUnload(hipModule_t module)63a825fb2cSChristian Sigg extern "C" void mgpuModuleUnload(hipModule_t module) {
64a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleUnload(module));
65a825fb2cSChristian Sigg }
66a825fb2cSChristian Sigg 
mgpuModuleGetFunction(hipModule_t module,const char * name)67a825fb2cSChristian Sigg extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
68a825fb2cSChristian Sigg                                                const char *name) {
69a825fb2cSChristian Sigg   hipFunction_t function = nullptr;
70a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
71a825fb2cSChristian Sigg   return function;
72a825fb2cSChristian Sigg }
73a825fb2cSChristian Sigg 
74a825fb2cSChristian Sigg // The wrapper uses intptr_t instead of ROCM's unsigned int to match
75a825fb2cSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the
76a825fb2cSChristian Sigg // generated MLIR code.
mgpuLaunchKernel(hipFunction_t function,intptr_t gridX,intptr_t gridY,intptr_t gridZ,intptr_t blockX,intptr_t blockY,intptr_t blockZ,int32_t smem,hipStream_t stream,void ** params,void ** extra)77a825fb2cSChristian Sigg extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
78a825fb2cSChristian Sigg                                  intptr_t gridY, intptr_t gridZ,
79a825fb2cSChristian Sigg                                  intptr_t blockX, intptr_t blockY,
80a825fb2cSChristian Sigg                                  intptr_t blockZ, int32_t smem,
81a825fb2cSChristian Sigg                                  hipStream_t stream, void **params,
82a825fb2cSChristian Sigg                                  void **extra) {
83a825fb2cSChristian Sigg   ScopedContext scopedContext;
84a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
85a825fb2cSChristian Sigg                                             blockX, blockY, blockZ, smem,
86a825fb2cSChristian Sigg                                             stream, params, extra));
87a825fb2cSChristian Sigg }
88a825fb2cSChristian Sigg 
mgpuStreamCreate()89a825fb2cSChristian Sigg extern "C" hipStream_t mgpuStreamCreate() {
90a825fb2cSChristian Sigg   ScopedContext scopedContext;
91a825fb2cSChristian Sigg   hipStream_t stream = nullptr;
92a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
93a825fb2cSChristian Sigg   return stream;
94a825fb2cSChristian Sigg }
95a825fb2cSChristian Sigg 
mgpuStreamDestroy(hipStream_t stream)96a825fb2cSChristian Sigg extern "C" void mgpuStreamDestroy(hipStream_t stream) {
97a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
98a825fb2cSChristian Sigg }
99a825fb2cSChristian Sigg 
mgpuStreamSynchronize(hipStream_t stream)100a825fb2cSChristian Sigg extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
101a825fb2cSChristian Sigg   return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
102a825fb2cSChristian Sigg }
103a825fb2cSChristian Sigg 
mgpuStreamWaitEvent(hipStream_t stream,hipEvent_t event)104a825fb2cSChristian Sigg extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
105a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
106a825fb2cSChristian Sigg }
107a825fb2cSChristian Sigg 
mgpuEventCreate()108a825fb2cSChristian Sigg extern "C" hipEvent_t mgpuEventCreate() {
109a825fb2cSChristian Sigg   ScopedContext scopedContext;
110a825fb2cSChristian Sigg   hipEvent_t event = nullptr;
111a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
112a825fb2cSChristian Sigg   return event;
113a825fb2cSChristian Sigg }
114a825fb2cSChristian Sigg 
mgpuEventDestroy(hipEvent_t event)115a825fb2cSChristian Sigg extern "C" void mgpuEventDestroy(hipEvent_t event) {
116a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventDestroy(event));
117a825fb2cSChristian Sigg }
118a825fb2cSChristian Sigg 
mgpuEventSynchronize(hipEvent_t event)119a825fb2cSChristian Sigg extern "C" void mgpuEventSynchronize(hipEvent_t event) {
120a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
121a825fb2cSChristian Sigg }
122a825fb2cSChristian Sigg 
mgpuEventRecord(hipEvent_t event,hipStream_t stream)123a825fb2cSChristian Sigg extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
124a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
125a825fb2cSChristian Sigg }
126a825fb2cSChristian Sigg 
mgpuMemAlloc(uint64_t sizeBytes,hipStream_t)127a825fb2cSChristian Sigg extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
128a825fb2cSChristian Sigg   ScopedContext scopedContext;
129a825fb2cSChristian Sigg   void *ptr;
130a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
131a825fb2cSChristian Sigg   return ptr;
132a825fb2cSChristian Sigg }
133a825fb2cSChristian Sigg 
mgpuMemFree(void * ptr,hipStream_t)134a825fb2cSChristian Sigg extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
135a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipFree(ptr));
136a825fb2cSChristian Sigg }
137a825fb2cSChristian Sigg 
mgpuMemcpy(void * dst,void * src,size_t sizeBytes,hipStream_t stream)138361458b1SLoren Maggiore extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
139a825fb2cSChristian Sigg                            hipStream_t stream) {
140a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(
141a825fb2cSChristian Sigg       hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
142a825fb2cSChristian Sigg }
143a825fb2cSChristian Sigg 
mgpuMemset32(void * dst,int value,size_t count,hipStream_t stream)144361458b1SLoren Maggiore extern "C" void mgpuMemset32(void *dst, int value, size_t count,
145361458b1SLoren Maggiore                              hipStream_t stream) {
146361458b1SLoren Maggiore   HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
147361458b1SLoren Maggiore                                         value, count, stream));
148361458b1SLoren Maggiore }
// Helper functions for writing MLIR example code.
150a825fb2cSChristian Sigg 
151a825fb2cSChristian Sigg // Allows to register byte array with the ROCM runtime. Helpful until we have
152a825fb2cSChristian Sigg // transfer functions implemented.
mgpuMemHostRegister(void * ptr,uint64_t sizeBytes)153a825fb2cSChristian Sigg extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
154a825fb2cSChristian Sigg   ScopedContext scopedContext;
155a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
156a825fb2cSChristian Sigg }
157a825fb2cSChristian Sigg 
158a825fb2cSChristian Sigg // Allows to register a MemRef with the ROCm runtime. Helpful until we have
159a825fb2cSChristian Sigg // transfer functions implemented.
160a825fb2cSChristian Sigg extern "C" void
mgpuMemHostRegisterMemRef(int64_t rank,StridedMemRefType<char,1> * descriptor,int64_t elementSizeBytes)161a825fb2cSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
162a825fb2cSChristian Sigg                           int64_t elementSizeBytes) {
163a825fb2cSChristian Sigg 
164a825fb2cSChristian Sigg   llvm::SmallVector<int64_t, 4> denseStrides(rank);
165a825fb2cSChristian Sigg   llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
166a825fb2cSChristian Sigg   llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
167a825fb2cSChristian Sigg 
168a825fb2cSChristian Sigg   std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
169a825fb2cSChristian Sigg                    std::multiplies<int64_t>());
170a825fb2cSChristian Sigg   auto sizeBytes = denseStrides.front() * elementSizeBytes;
171a825fb2cSChristian Sigg 
172a825fb2cSChristian Sigg   // Only densely packed tensors are currently supported.
173a825fb2cSChristian Sigg   std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
174a825fb2cSChristian Sigg               denseStrides.end());
175a825fb2cSChristian Sigg   denseStrides.back() = 1;
176a825fb2cSChristian Sigg   assert(strides == llvm::makeArrayRef(denseStrides));
177a825fb2cSChristian Sigg 
178a825fb2cSChristian Sigg   auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
179a825fb2cSChristian Sigg   mgpuMemHostRegister(ptr, sizeBytes);
180a825fb2cSChristian Sigg }
181a825fb2cSChristian Sigg 
182a825fb2cSChristian Sigg template <typename T>
mgpuMemGetDevicePointer(T * hostPtr,T ** devicePtr)183a825fb2cSChristian Sigg void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
184a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(hipSetDevice(0));
185a825fb2cSChristian Sigg   HIP_REPORT_IF_ERROR(
186a825fb2cSChristian Sigg       hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0));
187a825fb2cSChristian Sigg }
188a825fb2cSChristian Sigg 
189a825fb2cSChristian Sigg extern "C" StridedMemRefType<float, 1>
mgpuMemGetDeviceMemRef1dFloat(float * allocated,float * aligned,int64_t offset,int64_t size,int64_t stride)190a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset,
191a825fb2cSChristian Sigg                               int64_t size, int64_t stride) {
192a825fb2cSChristian Sigg   float *devicePtr = nullptr;
193a825fb2cSChristian Sigg   mgpuMemGetDevicePointer(aligned, &devicePtr);
194a825fb2cSChristian Sigg   return {devicePtr, devicePtr, offset, {size}, {stride}};
195a825fb2cSChristian Sigg }
196a825fb2cSChristian Sigg 
197a825fb2cSChristian Sigg extern "C" StridedMemRefType<int32_t, 1>
mgpuMemGetDeviceMemRef1dInt32(int32_t * allocated,int32_t * aligned,int64_t offset,int64_t size,int64_t stride)198a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
199a825fb2cSChristian Sigg                               int64_t offset, int64_t size, int64_t stride) {
200a825fb2cSChristian Sigg   int32_t *devicePtr = nullptr;
201a825fb2cSChristian Sigg   mgpuMemGetDevicePointer(aligned, &devicePtr);
202a825fb2cSChristian Sigg   return {devicePtr, devicePtr, offset, {size}, {stride}};
203a825fb2cSChristian Sigg }
204*84718d37SKrzysztof Drewniak 
mgpuSetDefaultDevice(int32_t device)205*84718d37SKrzysztof Drewniak extern "C" void mgpuSetDefaultDevice(int32_t device) {
206*84718d37SKrzysztof Drewniak   defaultDevice = device;
207*84718d37SKrzysztof Drewniak   HIP_REPORT_IF_ERROR(hipSetDevice(device));
208*84718d37SKrzysztof Drewniak }
209