1 //===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Implements C wrappers around the CUDA library for easy linking in ORC jit. 10 // Also adds some debugging helpers that are helpful when writing MLIR code to 11 // run on GPUs. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "mlir/ExecutionEngine/CRunnerUtils.h" 16 17 #include "cuda.h" 18 19 #ifdef _WIN32 20 #define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport) 21 #else 22 #define MLIR_CUDA_WRAPPERS_EXPORT 23 #endif // _WIN32 24 25 #define CUDA_REPORT_IF_ERROR(expr) \ 26 [](CUresult result) { \ 27 if (!result) \ 28 return; \ 29 const char *name = nullptr; \ 30 cuGetErrorName(result, &name); \ 31 if (!name) \ 32 name = "<unknown>"; \ 33 fprintf(stderr, "'%s' failed with '%s'\n", #expr, name); \ 34 }(expr) 35 36 // Make the primary context of device 0 current for the duration of the instance 37 // and restore the previous context on destruction. 38 class ScopedContext { 39 public: 40 ScopedContext() { 41 // Static reference to CUDA primary context for device ordinal 0. 42 static CUcontext context = [] { 43 CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0)); 44 CUdevice device; 45 CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0)); 46 CUcontext ctx; 47 // Note: this does not affect the current context. 48 CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device)); 49 return ctx; 50 }(); 51 52 CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context)); 53 } 54 55 ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); } 56 }; 57 58 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) { 59 ScopedContext scopedContext; 60 CUmodule module = nullptr; 61 CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data)); 62 return module; 63 } 64 65 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) { 66 CUDA_REPORT_IF_ERROR(cuModuleUnload(module)); 67 } 68 69 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction 70 mgpuModuleGetFunction(CUmodule module, const char *name) { 71 CUfunction function = nullptr; 72 CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name)); 73 return function; 74 } 75 76 // The wrapper uses intptr_t instead of CUDA's unsigned int to match 77 // the type of MLIR's index type. This avoids the need for casts in the 78 // generated MLIR code. 79 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 80 mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY, 81 intptr_t gridZ, intptr_t blockX, intptr_t blockY, 82 intptr_t blockZ, int32_t smem, CUstream stream, void **params, 83 void **extra) { 84 ScopedContext scopedContext; 85 CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX, 86 blockY, blockZ, smem, stream, params, 87 extra)); 88 } 89 90 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() { 91 ScopedContext scopedContext; 92 CUstream stream = nullptr; 93 CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING)); 94 return stream; 95 } 96 97 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) { 98 CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream)); 99 } 100 101 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 102 mgpuStreamSynchronize(CUstream stream) { 103 CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream)); 104 } 105 106 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream, 107 CUevent event) { 108 CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0)); 109 } 110 111 extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() { 112 ScopedContext scopedContext; 113 CUevent event = nullptr; 114 CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING)); 115 return event; 116 } 117 118 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) { 119 CUDA_REPORT_IF_ERROR(cuEventDestroy(event)); 120 } 121 122 extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventSynchronize(CUevent event) { 123 CUDA_REPORT_IF_ERROR(cuEventSynchronize(event)); 124 } 125 126 extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventRecord(CUevent event, 127 CUstream stream) { 128 CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream)); 129 } 130 131 extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) { 132 ScopedContext scopedContext; 133 CUdeviceptr ptr; 134 CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes)); 135 return reinterpret_cast<void *>(ptr); 136 } 137 138 extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) { 139 CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr))); 140 } 141 142 extern "C" void mgpuMemcpy(void *dst, void *src, uint64_t sizeBytes, 143 CUstream stream) { 144 CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst), 145 reinterpret_cast<CUdeviceptr>(src), 146 sizeBytes, stream)); 147 } 148 149 /// Helper functions for writing mlir example code 150 151 // Allows to register byte array with the CUDA runtime. Helpful until we have 152 // transfer functions implemented. 153 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 154 mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) { 155 ScopedContext scopedContext; 156 CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0)); 157 } 158 159 /// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a 160 /// ranked memref descriptor struct of rank `rank`. Helpful until we have 161 /// transfer functions implemented. 162 extern "C" MLIR_CUDA_WRAPPERS_EXPORT void 163 mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor, 164 int64_t elementSizeBytes) { 165 // Only densely packed tensors are currently supported. 166 int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t)); 167 int64_t *sizes = descriptor->sizes; 168 for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) { 169 denseStrides[i] = runningStride; 170 runningStride *= sizes[i]; 171 } 172 uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes; 173 int64_t *strides = &sizes[rank]; 174 for (unsigned i = 0; i < rank; ++i) 175 assert(strides[i] == denseStrides[i] && 176 "Mismatch in computed dense strides"); 177 178 auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes; 179 mgpuMemHostRegister(ptr, sizeBytes); 180 } 181