//===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implements C wrappers around the CUDA library for easy linking in ORC JIT.
// Also adds some debugging helpers that are helpful when writing MLIR code to
// run on GPUs.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/CRunnerUtils.h"

#include <cassert>
#include <stdio.h>

#ifdef _WIN32
#include <malloc.h> // Provides alloca on Windows.
#else
#include <alloca.h>
#endif // _WIN32

#include "cuda.h"

#ifdef _WIN32
#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
#else
#define MLIR_CUDA_WRAPPERS_EXPORT
#endif // _WIN32

#define CUDA_REPORT_IF_ERROR(expr)                                            \
  [](CUresult result) {                                                       \
    if (!result)                                                              \
      return;                                                                 \
    const char *name = nullptr;                                               \
    cuGetErrorName(result, &name);                                            \
    if (!name)                                                                \
      name = "<unknown>";                                                     \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr, name);                  \
  }(expr)

// Makes the primary context of device 0 current for the duration of the
// instance and restores the previous context on destruction.
class ScopedContext {
public:
  ScopedContext() {
    // Static reference to CUDA primary context for device ordinal 0.
    static CUcontext context = [] {
      CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
      CUdevice device;
      CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/0));
      CUcontext ctx;
      // Note: this does not affect the current context.
      CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
      return ctx;
    }();

    CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
  }

  ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
};

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {
  ScopedContext scopedContext;
  CUmodule module = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
  return module;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
  CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
mgpuModuleGetFunction(CUmodule module, const char *name) {
  CUfunction function = nullptr;
  CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
  return function;
}
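
// The wrappers below are typically invoked in sequence by host code that the
// MLIR GPU lowering generates. As a rough illustration (the cubin blob, the
// kernel name, and the params array are hypothetical, not part of this
// library):
//
//   CUmodule module = mgpuModuleLoad(cubinBlob);
//   CUfunction function = mgpuModuleGetFunction(module, "my_kernel");
//   CUstream stream = mgpuStreamCreate();
//   mgpuLaunchKernel(function, /*grid=*/1, 1, 1, /*block=*/128, 1, 1,
//                    /*smem=*/0, stream, params, /*extra=*/nullptr);
//   mgpuStreamSynchronize(stream);
//   mgpuStreamDestroy(stream);
//   mgpuModuleUnload(module);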

// The wrapper uses intptr_t instead of CUDA's unsigned int to match
// the type of MLIR's index type. This avoids the need for casts in the
// generated MLIR code.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
                 intptr_t gridZ, intptr_t blockX, intptr_t blockY,
                 intptr_t blockZ, int32_t smem, CUstream stream, void **params,
                 void **extra) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
                                      blockY, blockZ, smem, stream, params,
                                      extra));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
  ScopedContext scopedContext;
  CUstream stream = nullptr;
  CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
  return stream;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuStreamSynchronize(CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
                                                              CUevent event) {
  CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
  ScopedContext scopedContext;
  CUevent event = nullptr;
  CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
  return event;
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventSynchronize(CUevent event) {
  CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventRecord(CUevent event,
                                                          CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuMemAlloc(uint64_t sizeBytes,
                                                        CUstream /*stream*/) {
  ScopedContext scopedContext;
  CUdeviceptr ptr;
  CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
  return reinterpret_cast<void *>(ptr);
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuMemFree(void *ptr,
                                                      CUstream /*stream*/) {
  CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemcpy(void *dst, void *src, size_t sizeBytes, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
                                     reinterpret_cast<CUdeviceptr>(src),
                                     sizeBytes, stream));
}

extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemset32(void *dst, unsigned int value, size_t count, CUstream stream) {
  CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
                                        value, count, stream));
}

/// Helper functions for writing MLIR example code

// Allows registering a byte array with the CUDA runtime. Helpful until we
// have transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
  ScopedContext scopedContext;
  CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
}

/// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
/// ranked memref descriptor struct of rank `rank`. Helpful until we have
/// transfer functions implemented.
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
                          int64_t elementSizeBytes) {
  // Only densely packed tensors are currently supported.
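  // As an illustration (values hypothetical): a memref with sizes [2, 3, 4]
  // is densely packed iff its strides are [12, 4, 1], i.e. each stride is the
  // product of the sizes to its right. The loop below computes these dense
  // strides from the innermost dimension outward.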
  int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
  int64_t *sizes = descriptor->sizes;
  for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
    denseStrides[i] = runningStride;
    runningStride *= sizes[i];
  }
  uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
  // In a ranked memref descriptor, the strides array immediately follows the
  // sizes array.
  int64_t *strides = &sizes[rank];
  // `strides` is only read by the asserts below; silence unused-variable
  // warnings in release builds.
  (void)strides;
  for (unsigned i = 0; i < rank; ++i)
    assert(strides[i] == denseStrides[i] &&
           "Mismatch in computed dense strides");

  auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
  mgpuMemHostRegister(ptr, sizeBytes);
}
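
// A minimal sketch of calling mgpuMemHostRegisterMemRef directly from C++,
// assuming a densely packed 1-D memref of 16 floats (generated MLIR code
// normally builds the descriptor; the values here are illustrative):
//
//   static float data[16];
//   StridedMemRefType<char, 1> descriptor;
//   descriptor.basePtr = reinterpret_cast<char *>(data);
//   descriptor.data = reinterpret_cast<char *>(data);
//   descriptor.offset = 0;
//   descriptor.sizes[0] = 16;  // In elements.
//   descriptor.strides[0] = 1; // Dense, in elements.
//   mgpuMemHostRegisterMemRef(/*rank=*/1, &descriptor,
//                             /*elementSizeBytes=*/sizeof(float));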