1a825fb2cSChristian Sigg //===- RocmRuntimeWrappers.cpp - MLIR ROCM runtime wrapper library --------===//
2a825fb2cSChristian Sigg //
3a825fb2cSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4a825fb2cSChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
5a825fb2cSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6a825fb2cSChristian Sigg //
7a825fb2cSChristian Sigg //===----------------------------------------------------------------------===//
8a825fb2cSChristian Sigg //
9a825fb2cSChristian Sigg // Implements C wrappers around the ROCM library for easy linking in ORC jit.
10a825fb2cSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to
11a825fb2cSChristian Sigg // run on GPUs.
12a825fb2cSChristian Sigg //
13a825fb2cSChristian Sigg //===----------------------------------------------------------------------===//
14a825fb2cSChristian Sigg
15a825fb2cSChristian Sigg #include <cassert>
16a825fb2cSChristian Sigg #include <numeric>
17a825fb2cSChristian Sigg
18a825fb2cSChristian Sigg #include "mlir/ExecutionEngine/CRunnerUtils.h"
19a825fb2cSChristian Sigg #include "llvm/ADT/ArrayRef.h"
20a825fb2cSChristian Sigg
21a825fb2cSChristian Sigg #include "hip/hip_runtime.h"
22a825fb2cSChristian Sigg
// Evaluates `expr` (an expression yielding a hipError_t) once. If the result
// is non-zero, prints the stringified expression together with the name
// reported by hipGetErrorName (or "<unknown>" when that is null) to stderr.
// Execution continues in either case.
#define HIP_REPORT_IF_ERROR(expr)                                              \
  [](hipError_t status) {                                                      \
    if (status) {                                                              \
      const char *errorName = hipGetErrorName(status);                         \
      fprintf(stderr, "'%s' failed with '%s'\n", #expr,                        \
              errorName ? errorName : "<unknown>");                            \
    }                                                                          \
  }(expr)
32a825fb2cSChristian Sigg
// Per-thread device ordinal. Written by mgpuSetDefaultDevice and read by
// ScopedContext when it retains a primary context. Starts out as 0.
static thread_local int32_t defaultDevice = 0;
34*84718d37SKrzysztof Drewniak
35a825fb2cSChristian Sigg // Sets the `Context` for the duration of the instance and restores the previous
36a825fb2cSChristian Sigg // context on destruction.
37a825fb2cSChristian Sigg class ScopedContext {
38a825fb2cSChristian Sigg public:
ScopedContext()39a825fb2cSChristian Sigg ScopedContext() {
40*84718d37SKrzysztof Drewniak // Static reference to HIP primary context for device ordinal defaultDevice.
41a825fb2cSChristian Sigg static hipCtx_t context = [] {
42a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipInit(/*flags=*/0));
43a825fb2cSChristian Sigg hipDevice_t device;
44*84718d37SKrzysztof Drewniak HIP_REPORT_IF_ERROR(hipDeviceGet(&device, /*ordinal=*/defaultDevice));
45a825fb2cSChristian Sigg hipCtx_t ctx;
46a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipDevicePrimaryCtxRetain(&ctx, device));
47a825fb2cSChristian Sigg return ctx;
48a825fb2cSChristian Sigg }();
49a825fb2cSChristian Sigg
50a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipCtxPushCurrent(context));
51a825fb2cSChristian Sigg }
52a825fb2cSChristian Sigg
~ScopedContext()53a825fb2cSChristian Sigg ~ScopedContext() { HIP_REPORT_IF_ERROR(hipCtxPopCurrent(nullptr)); }
54a825fb2cSChristian Sigg };
55a825fb2cSChristian Sigg
mgpuModuleLoad(void * data)56a825fb2cSChristian Sigg extern "C" hipModule_t mgpuModuleLoad(void *data) {
57a825fb2cSChristian Sigg ScopedContext scopedContext;
58a825fb2cSChristian Sigg hipModule_t module = nullptr;
59a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleLoadData(&module, data));
60a825fb2cSChristian Sigg return module;
61a825fb2cSChristian Sigg }
62a825fb2cSChristian Sigg
mgpuModuleUnload(hipModule_t module)63a825fb2cSChristian Sigg extern "C" void mgpuModuleUnload(hipModule_t module) {
64a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleUnload(module));
65a825fb2cSChristian Sigg }
66a825fb2cSChristian Sigg
mgpuModuleGetFunction(hipModule_t module,const char * name)67a825fb2cSChristian Sigg extern "C" hipFunction_t mgpuModuleGetFunction(hipModule_t module,
68a825fb2cSChristian Sigg const char *name) {
69a825fb2cSChristian Sigg hipFunction_t function = nullptr;
70a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleGetFunction(&function, module, name));
71a825fb2cSChristian Sigg return function;
72a825fb2cSChristian Sigg }
73a825fb2cSChristian Sigg
74a825fb2cSChristian Sigg // The wrapper uses intptr_t instead of ROCM's unsigned int to match
75a825fb2cSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the
76a825fb2cSChristian Sigg // generated MLIR code.
mgpuLaunchKernel(hipFunction_t function,intptr_t gridX,intptr_t gridY,intptr_t gridZ,intptr_t blockX,intptr_t blockY,intptr_t blockZ,int32_t smem,hipStream_t stream,void ** params,void ** extra)77a825fb2cSChristian Sigg extern "C" void mgpuLaunchKernel(hipFunction_t function, intptr_t gridX,
78a825fb2cSChristian Sigg intptr_t gridY, intptr_t gridZ,
79a825fb2cSChristian Sigg intptr_t blockX, intptr_t blockY,
80a825fb2cSChristian Sigg intptr_t blockZ, int32_t smem,
81a825fb2cSChristian Sigg hipStream_t stream, void **params,
82a825fb2cSChristian Sigg void **extra) {
83a825fb2cSChristian Sigg ScopedContext scopedContext;
84a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ,
85a825fb2cSChristian Sigg blockX, blockY, blockZ, smem,
86a825fb2cSChristian Sigg stream, params, extra));
87a825fb2cSChristian Sigg }
88a825fb2cSChristian Sigg
mgpuStreamCreate()89a825fb2cSChristian Sigg extern "C" hipStream_t mgpuStreamCreate() {
90a825fb2cSChristian Sigg ScopedContext scopedContext;
91a825fb2cSChristian Sigg hipStream_t stream = nullptr;
92a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipStreamCreate(&stream));
93a825fb2cSChristian Sigg return stream;
94a825fb2cSChristian Sigg }
95a825fb2cSChristian Sigg
mgpuStreamDestroy(hipStream_t stream)96a825fb2cSChristian Sigg extern "C" void mgpuStreamDestroy(hipStream_t stream) {
97a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipStreamDestroy(stream));
98a825fb2cSChristian Sigg }
99a825fb2cSChristian Sigg
mgpuStreamSynchronize(hipStream_t stream)100a825fb2cSChristian Sigg extern "C" void mgpuStreamSynchronize(hipStream_t stream) {
101a825fb2cSChristian Sigg return HIP_REPORT_IF_ERROR(hipStreamSynchronize(stream));
102a825fb2cSChristian Sigg }
103a825fb2cSChristian Sigg
mgpuStreamWaitEvent(hipStream_t stream,hipEvent_t event)104a825fb2cSChristian Sigg extern "C" void mgpuStreamWaitEvent(hipStream_t stream, hipEvent_t event) {
105a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipStreamWaitEvent(stream, event, /*flags=*/0));
106a825fb2cSChristian Sigg }
107a825fb2cSChristian Sigg
mgpuEventCreate()108a825fb2cSChristian Sigg extern "C" hipEvent_t mgpuEventCreate() {
109a825fb2cSChristian Sigg ScopedContext scopedContext;
110a825fb2cSChristian Sigg hipEvent_t event = nullptr;
111a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventCreateWithFlags(&event, hipEventDisableTiming));
112a825fb2cSChristian Sigg return event;
113a825fb2cSChristian Sigg }
114a825fb2cSChristian Sigg
mgpuEventDestroy(hipEvent_t event)115a825fb2cSChristian Sigg extern "C" void mgpuEventDestroy(hipEvent_t event) {
116a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventDestroy(event));
117a825fb2cSChristian Sigg }
118a825fb2cSChristian Sigg
mgpuEventSynchronize(hipEvent_t event)119a825fb2cSChristian Sigg extern "C" void mgpuEventSynchronize(hipEvent_t event) {
120a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventSynchronize(event));
121a825fb2cSChristian Sigg }
122a825fb2cSChristian Sigg
mgpuEventRecord(hipEvent_t event,hipStream_t stream)123a825fb2cSChristian Sigg extern "C" void mgpuEventRecord(hipEvent_t event, hipStream_t stream) {
124a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipEventRecord(event, stream));
125a825fb2cSChristian Sigg }
126a825fb2cSChristian Sigg
mgpuMemAlloc(uint64_t sizeBytes,hipStream_t)127a825fb2cSChristian Sigg extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, hipStream_t /*stream*/) {
128a825fb2cSChristian Sigg ScopedContext scopedContext;
129a825fb2cSChristian Sigg void *ptr;
130a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipMalloc(&ptr, sizeBytes));
131a825fb2cSChristian Sigg return ptr;
132a825fb2cSChristian Sigg }
133a825fb2cSChristian Sigg
mgpuMemFree(void * ptr,hipStream_t)134a825fb2cSChristian Sigg extern "C" void mgpuMemFree(void *ptr, hipStream_t /*stream*/) {
135a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipFree(ptr));
136a825fb2cSChristian Sigg }
137a825fb2cSChristian Sigg
mgpuMemcpy(void * dst,void * src,size_t sizeBytes,hipStream_t stream)138361458b1SLoren Maggiore extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
139a825fb2cSChristian Sigg hipStream_t stream) {
140a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(
141a825fb2cSChristian Sigg hipMemcpyAsync(dst, src, sizeBytes, hipMemcpyDefault, stream));
142a825fb2cSChristian Sigg }
143a825fb2cSChristian Sigg
mgpuMemset32(void * dst,int value,size_t count,hipStream_t stream)144361458b1SLoren Maggiore extern "C" void mgpuMemset32(void *dst, int value, size_t count,
145361458b1SLoren Maggiore hipStream_t stream) {
146361458b1SLoren Maggiore HIP_REPORT_IF_ERROR(hipMemsetD32Async(reinterpret_cast<hipDeviceptr_t>(dst),
147361458b1SLoren Maggiore value, count, stream));
148361458b1SLoren Maggiore }
/// Helper functions for writing MLIR example code.
150a825fb2cSChristian Sigg
151a825fb2cSChristian Sigg // Allows to register byte array with the ROCM runtime. Helpful until we have
152a825fb2cSChristian Sigg // transfer functions implemented.
mgpuMemHostRegister(void * ptr,uint64_t sizeBytes)153a825fb2cSChristian Sigg extern "C" void mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
154a825fb2cSChristian Sigg ScopedContext scopedContext;
155a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipHostRegister(ptr, sizeBytes, /*flags=*/0));
156a825fb2cSChristian Sigg }
157a825fb2cSChristian Sigg
158a825fb2cSChristian Sigg // Allows to register a MemRef with the ROCm runtime. Helpful until we have
159a825fb2cSChristian Sigg // transfer functions implemented.
160a825fb2cSChristian Sigg extern "C" void
mgpuMemHostRegisterMemRef(int64_t rank,StridedMemRefType<char,1> * descriptor,int64_t elementSizeBytes)161a825fb2cSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
162a825fb2cSChristian Sigg int64_t elementSizeBytes) {
163a825fb2cSChristian Sigg
164a825fb2cSChristian Sigg llvm::SmallVector<int64_t, 4> denseStrides(rank);
165a825fb2cSChristian Sigg llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank);
166a825fb2cSChristian Sigg llvm::ArrayRef<int64_t> strides(sizes.end(), rank);
167a825fb2cSChristian Sigg
168a825fb2cSChristian Sigg std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(),
169a825fb2cSChristian Sigg std::multiplies<int64_t>());
170a825fb2cSChristian Sigg auto sizeBytes = denseStrides.front() * elementSizeBytes;
171a825fb2cSChristian Sigg
172a825fb2cSChristian Sigg // Only densely packed tensors are currently supported.
173a825fb2cSChristian Sigg std::rotate(denseStrides.begin(), denseStrides.begin() + 1,
174a825fb2cSChristian Sigg denseStrides.end());
175a825fb2cSChristian Sigg denseStrides.back() = 1;
176a825fb2cSChristian Sigg assert(strides == llvm::makeArrayRef(denseStrides));
177a825fb2cSChristian Sigg
178a825fb2cSChristian Sigg auto ptr = descriptor->data + descriptor->offset * elementSizeBytes;
179a825fb2cSChristian Sigg mgpuMemHostRegister(ptr, sizeBytes);
180a825fb2cSChristian Sigg }
181a825fb2cSChristian Sigg
182a825fb2cSChristian Sigg template <typename T>
mgpuMemGetDevicePointer(T * hostPtr,T ** devicePtr)183a825fb2cSChristian Sigg void mgpuMemGetDevicePointer(T *hostPtr, T **devicePtr) {
184a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(hipSetDevice(0));
185a825fb2cSChristian Sigg HIP_REPORT_IF_ERROR(
186a825fb2cSChristian Sigg hipHostGetDevicePointer((void **)devicePtr, hostPtr, /*flags=*/0));
187a825fb2cSChristian Sigg }
188a825fb2cSChristian Sigg
189a825fb2cSChristian Sigg extern "C" StridedMemRefType<float, 1>
mgpuMemGetDeviceMemRef1dFloat(float * allocated,float * aligned,int64_t offset,int64_t size,int64_t stride)190a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dFloat(float *allocated, float *aligned, int64_t offset,
191a825fb2cSChristian Sigg int64_t size, int64_t stride) {
192a825fb2cSChristian Sigg float *devicePtr = nullptr;
193a825fb2cSChristian Sigg mgpuMemGetDevicePointer(aligned, &devicePtr);
194a825fb2cSChristian Sigg return {devicePtr, devicePtr, offset, {size}, {stride}};
195a825fb2cSChristian Sigg }
196a825fb2cSChristian Sigg
197a825fb2cSChristian Sigg extern "C" StridedMemRefType<int32_t, 1>
mgpuMemGetDeviceMemRef1dInt32(int32_t * allocated,int32_t * aligned,int64_t offset,int64_t size,int64_t stride)198a825fb2cSChristian Sigg mgpuMemGetDeviceMemRef1dInt32(int32_t *allocated, int32_t *aligned,
199a825fb2cSChristian Sigg int64_t offset, int64_t size, int64_t stride) {
200a825fb2cSChristian Sigg int32_t *devicePtr = nullptr;
201a825fb2cSChristian Sigg mgpuMemGetDevicePointer(aligned, &devicePtr);
202a825fb2cSChristian Sigg return {devicePtr, devicePtr, offset, {size}, {stride}};
203a825fb2cSChristian Sigg }
204*84718d37SKrzysztof Drewniak
mgpuSetDefaultDevice(int32_t device)205*84718d37SKrzysztof Drewniak extern "C" void mgpuSetDefaultDevice(int32_t device) {
206*84718d37SKrzysztof Drewniak defaultDevice = device;
207*84718d37SKrzysztof Drewniak HIP_REPORT_IF_ERROR(hipSetDevice(device));
208*84718d37SKrzysztof Drewniak }
209