1 //===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements C wrappers around the CUDA library for easy linking in ORC jit. 10 // Also adds some debugging helpers that are helpful when writing MLIR code to 11 // run on GPUs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include <cassert> 16 #include <numeric> 17 18 #include "mlir/ExecutionEngine/CRunnerUtils.h" 19 #include "llvm/ADT/ArrayRef.h" 20 21 #include "cuda.h" 22 23 #ifdef _WIN32 24 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) 25 #else 26 #define MLIR_CUDA_WRAPPERS_EXPORT 27 #endif // _WIN32 28 29 #define CUDA_REPORT_IF_ERROR(expr) \ 30 [](CUresult result) { \ 31 if (!result) \ 32 return; \ 33 const char *name = nullptr; \ 34 cuGetErrorName(result, &name); \ 35 if (!name) \ 36 name = "<unknown>"; \ 37 fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ 38 }(expr) 39 40 // Make the primary context of device 0 current for the duration of the instance 41 // and restore the previous context on destruction. 42 class ScopedContext { 43 public: 44 ScopedContext() { 45 // Static reference to CUDA primary context for device ordinal 0. 46 static CUcontext context = [] { 47 CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); 48 CUdevice device; 49 CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); 50 CUcontext ctx; 51 // Note: this does not affect the current context. 52 CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device)); 53 return ctx; 54 }(); 55 56 CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context)); 57 } 58 59 ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); } 60 }; 61 62 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) { 63 ScopedContext scopedContext; 64 CUmodule module = nullptr; 65 CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); 66 return module; 67 } 68 69 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) { 70 CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); 71 } 72 73 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction 74 mgpuModuleGetFunction(CUmodule module, const char *name) { 75 CUfunction function = nullptr; 76 CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); 77 return function; 78 } 79 80 // The wrapper uses intptr_t instead of CUDA's unsigned int to match 81 // the type of MLIR's index type. This avoids the need for casts in the 82 // generated MLIR code. 83 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 84 mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, 85 intptr_t gridZ, intptr_t blockX, intptr_t blockY, 86 intptr_t blockZ, int32_t smem, CUstream stream, void **params, 87 void **extra) { 88 ScopedContext scopedContext; 89 CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, 90 blockY, blockZ, smem, stream, params, 91 extra)); 92 } 93 94 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() { 95 ScopedContext scopedContext; 96 CUstream stream = nullptr; 97 CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); 98 return stream; 99 } 100 101 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) { 102 CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream)); 103 } 104 105 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 106 mgpuStreamSynchronize(CUstream stream) { 107 CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); 108 } 109 110 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream, 111 CUevent event) { 112 CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0)); 113 } 114 115 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() { 116 ScopedContext scopedContext; 117 CUevent event = nullptr; 118 CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING)); 119 return event; 120 } 121 122 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) { 123 CUDA_REPORT_IF_ERROR(cuEventDestroy(event)); 124 } 125 126 extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventSynchronize(CUevent event) { 127 CUDA_REPORT_IF_ERROR(cuEventSynchronize(event)); 128 } 129 130 extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventRecord(CUevent event, 131 CUstream stream) { 132 CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream)); 133 } 134 135 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) { 136 ScopedContext scopedContext; 137 CUdeviceptr ptr; 138 CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes)); 139 return reinterpret_cast<void *>(ptr); 140 } 141 142 extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) { 143 CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr))); 144 } 145 146 extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes, 147 CUstream stream) { 148 CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst), 149 reinterpret_cast<CUdeviceptr>(src), 150 sizeBytes, stream)); 151 } 152 153 /// Helper functions for writing mlir example code 154 155 // Allows to register byte array with the CUDA runtime. Helpful until we have 156 // transfer functions implemented. 157 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 158 mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { 159 ScopedContext scopedContext; 160 CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); 161 } 162 163 // Allows to register a MemRef with the CUDA runtime. Helpful until we have 164 // transfer functions implemented. 165 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 166 mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, 167 int64_t elementSizeBytes) { 168 169 llvm::SmallVector<int64_t, 4> denseStrides(rank); 170 llvm::ArrayRef<int64_t> sizes(descriptor->sizes, rank); 171 llvm::ArrayRef<int64_t> strides(sizes.end(), rank); 172 173 std::partial_sum(sizes.rbegin(), sizes.rend(), denseStrides.rbegin(), 174 std::multiplies<int64_t>()); 175 auto sizeBytes = denseStrides.front() * elementSizeBytes; 176 177 // Only densely packed tensors are currently supported. 178 std::rotate(denseStrides.begin(), denseStrides.begin() + 1, 179 denseStrides.end()); 180 denseStrides.back() = 1; 181 assert(strides == llvm::makeArrayRef(denseStrides)); 182 183 auto ptr = descriptor->data + descriptor->offset * elementSizeBytes; 184 mgpuMemHostRegister(ptr, sizeBytes); 185 } 186