1 //===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements C wrappers around the CUDA library for easy linking in ORC jit. 10 // Also adds some debugging helpers that are helpful when writing MLIR code to 11 // run on GPUs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include <cassert> 16 #include <numeric> 17 18 #include "mlir/ExecutionEngine/CRunnerUtils.h" 19 #include "llvm/ADT/ArrayRef.h" 20 21 #include "cuda.h" 22 23 #ifdef _WIN32 24 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) 25 #else 26 #define MLIR_CUDA_WRAPPERS_EXPORT 27 #endif // _WIN32 28 29 #define CUDA_REPORT_IF_ERROR(expr) \ 30 [](CUresult result) { \ 31 if (!result) \ 32 return; \ 33 const char *name = nullptr; \ 34 cuGetErrorName(result, &name); \ 35 if (!name) \ 36 name = "<unknown>"; \ 37 fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ 38 }(expr) 39 40 // Static reference to CUDA primary context for device ordinal 0. 41 static CUcontext Context = [] { 42 CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); 43 CUdevice device; 44 CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); 45 CUcontext context; 46 CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&context, device)); 47 return context; 48 }(); 49 50 // Sets the `Context` for the duration of the instance and restores the previous 51 // context on destruction. 52 class ScopedContext { 53 public: 54 ScopedContext() { 55 CUDA_REPORT_IF_ERROR(cuCtxGetCurrent(&previous)); 56 CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(Context)); 57 } 58 59 ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxSetCurrent(previous)); } 60 61 private: 62 CUcontext previous; 63 }; 64 65 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) { 66 ScopedContext scopedContext; 67 CUmodule module = nullptr; 68 CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); 69 return module; 70 } 71 72 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) { 73 CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); 74 } 75 76 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction 77 mgpuModuleGetFunction(CUmodule module, const char *name) { 78 CUfunction function = nullptr; 79 CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); 80 return function; 81 } 82 83 // The wrapper uses intptr_t instead of CUDA's unsigned int to match 84 // the type of MLIR's index type. This avoids the need for casts in the 85 // generated MLIR code. 86 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 87 mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, 88 intptr_t gridZ, intptr_t blockX, intptr_t blockY, 89 intptr_t blockZ, int32_t smem, CUstream stream, void **params, 90 void **extra) { 91 ScopedContext scopedContext; 92 CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, 93 blockY, blockZ, smem, stream, params, 94 extra)); 95 } 96 97 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() { 98 ScopedContext scopedContext; 99 CUstream stream = nullptr; 100 CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); 101 return stream; 102 } 103 104 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) { 105 CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream)); 106 } 107 108 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 109 mgpuStreamSynchronize(CUstream stream) { 110 CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); 111 } 112 113 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream, 114 CUevent event) { 115 CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0)); 116 } 117 118 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() { 119 ScopedContext scopedContext; 120 CUevent event = nullptr; 121 CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING)); 122 return event; 123 } 124 125 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) { 126 CUDA_REPORT_IF_ERROR(cuEventDestroy(event)); 127 } 128 129 extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventSynchronize(CUevent event) { 130 CUDA_REPORT_IF_ERROR(cuEventSynchronize(event)); 131 } 132 133 extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventRecord(CUevent event, 134 CUstream stream) { 135 CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream)); 136 } 137 138 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) { 139 ScopedContext scopedContext; 140 CUdeviceptr ptr; 141 CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes)); 142 return reinterpret_cast<void *>(ptr); 143 } 144 145 extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) { 146 CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr))); 147 } 148 149 extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes, 150 CUstream stream) { 151 CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst), 152 reinterpret_cast<CUdeviceptr>(src), 153 sizeBytes, stream)); 154 } 155 156 /// Helper functions for writing mlir example code 157 158 // Allows to register byte array with the CUDA runtime. Helpful until we have 159 // transfer functions implemented. 160 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 161 mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { 162 ScopedContext scopedContext; 163 CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); 164 } 165 166 // Allows to register a MemRef with the CUDA runtime. Helpful until we have 167 // transfer functions implemented. 168 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 169 mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, 170 int64_t elementSizeBytes) { 171 172 llvm::SmallVector<int64_t, 4> denseStrides(rank); 173 llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank); 174 llvm::ArrayRef<int64_t> strides(sizes.end(), rank); 175 176 std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), 177 std::multiplies<int64_t>()); 178 auto sizeBytes = denseStrides.front() * elementSizeBytes; 179 180 // Only densely packed tensors are currently supported. 181 std::rotate(denseStrides.begin(), denseStrides.begin() + 1, 182 denseStrides.end()); 183 denseStrides.back() = 1; 184 assert(strides == llvm::makeArrayRef(denseStrides)); 185 186 auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; 187 mgpuMemHostRegister(ptr, sizeBytes); 188 } 189