19d7be77bSChristian Sigg //===- CudaRuntimeWrappers.cpp - MLIR CUDA API wrapper library ------------===//
29d7be77bSChristian Sigg //
39d7be77bSChristian Sigg // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
49d7be77bSChristian Sigg // See https://llvm.org/LICENSE.txt for license information.
59d7be77bSChristian Sigg // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
69d7be77bSChristian Sigg //
79d7be77bSChristian Sigg //===----------------------------------------------------------------------===//
89d7be77bSChristian Sigg //
99d7be77bSChristian Sigg // Implements C wrappers around the CUDA library for easy linking in ORC jit.
109d7be77bSChristian Sigg // Also adds some debugging helpers that are helpful when writing MLIR code to
119d7be77bSChristian Sigg // run on GPUs.
129d7be77bSChristian Sigg //
139d7be77bSChristian Sigg //===----------------------------------------------------------------------===//
149d7be77bSChristian Sigg 
#include "mlir/ExecutionEngine/CRunnerUtils.h"

#include <mutex>
#include <stdio.h>
#include <unordered_map>

#include "cuda.h"
209d7be77bSChristian Sigg 
// Symbols of this wrapper library must be visible to the JIT; on Windows that
// requires an explicit dllexport, elsewhere the macro expands to nothing.
#if defined(_WIN32)
#define MLIR_CUDA_WRAPPERS_EXPORT __declspec(dllexport)
#else
#define MLIR_CUDA_WRAPPERS_EXPORT
#endif // defined(_WIN32)

// Evaluates `expr` (a CUresult-returning driver call) exactly once. If the
// result is not CUDA_SUCCESS, prints the stringified expression and the
// driver's error name (or "<unknown>" when the driver has no name for it) to
// stderr. The error is only reported; execution continues.
#define CUDA_REPORT_IF_ERROR(expr)                                             \
  [](CUresult status) {                                                        \
    if (status == CUDA_SUCCESS)                                                \
      return;                                                                  \
    const char *errorName = nullptr;                                           \
    cuGetErrorName(status, &errorName);                                        \
    fprintf(stderr, "'%s' failed with '%s'\n", #expr,                          \
            errorName != nullptr ? errorName : "<unknown>");                   \
  }(expr)
379d7be77bSChristian Sigg 
3884718d37SKrzysztof Drewniak thread_local static int32_t defaultDevice = 0;
3984718d37SKrzysztof Drewniak 
4084718d37SKrzysztof Drewniak // Make the primary context of the current default device current for the
4184718d37SKrzysztof Drewniak // duration
4284718d37SKrzysztof Drewniak //  of the instance and restore the previous context on destruction.
439d7be77bSChristian Sigg class ScopedContext {
449d7be77bSChristian Sigg public:
ScopedContext()459d7be77bSChristian Sigg   ScopedContext() {
4684718d37SKrzysztof Drewniak     // Static reference to CUDA primary context for device ordinal
4784718d37SKrzysztof Drewniak     // defaultDevice.
48f69d5a7fSChristian Sigg     static CUcontext context = [] {
49f69d5a7fSChristian Sigg       CUDA_REPORT_IF_ERROR(cuInit(/*flags=*/0));
50f69d5a7fSChristian Sigg       CUdevice device;
5184718d37SKrzysztof Drewniak       CUDA_REPORT_IF_ERROR(cuDeviceGet(&device, /*ordinal=*/defaultDevice));
52f69d5a7fSChristian Sigg       CUcontext ctx;
53f69d5a7fSChristian Sigg       // Note: this does not affect the current context.
54f69d5a7fSChristian Sigg       CUDA_REPORT_IF_ERROR(cuDevicePrimaryCtxRetain(&ctx, device));
55f69d5a7fSChristian Sigg       return ctx;
56f69d5a7fSChristian Sigg     }();
57f69d5a7fSChristian Sigg 
58f69d5a7fSChristian Sigg     CUDA_REPORT_IF_ERROR(cuCtxPushCurrent(context));
599d7be77bSChristian Sigg   }
609d7be77bSChristian Sigg 
~ScopedContext()61f69d5a7fSChristian Sigg   ~ScopedContext() { CUDA_REPORT_IF_ERROR(cuCtxPopCurrent(nullptr)); }
629d7be77bSChristian Sigg };
639d7be77bSChristian Sigg 
mgpuModuleLoad(void * data)649d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUmodule mgpuModuleLoad(void *data) {
659d7be77bSChristian Sigg   ScopedContext scopedContext;
669d7be77bSChristian Sigg   CUmodule module = nullptr;
679d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuModuleLoadData(&module, data));
689d7be77bSChristian Sigg   return module;
699d7be77bSChristian Sigg }
709d7be77bSChristian Sigg 
mgpuModuleUnload(CUmodule module)719d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuModuleUnload(CUmodule module) {
729d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuModuleUnload(module));
739d7be77bSChristian Sigg }
749d7be77bSChristian Sigg 
759d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUfunction
mgpuModuleGetFunction(CUmodule module,const char * name)769d7be77bSChristian Sigg mgpuModuleGetFunction(CUmodule module, const char *name) {
779d7be77bSChristian Sigg   CUfunction function = nullptr;
789d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuModuleGetFunction(&function, module, name));
799d7be77bSChristian Sigg   return function;
809d7be77bSChristian Sigg }
819d7be77bSChristian Sigg 
829d7be77bSChristian Sigg // The wrapper uses intptr_t instead of CUDA's unsigned int to match
839d7be77bSChristian Sigg // the type of MLIR's index type. This avoids the need for casts in the
849d7be77bSChristian Sigg // generated MLIR code.
859d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuLaunchKernel(CUfunction function,intptr_t gridX,intptr_t gridY,intptr_t gridZ,intptr_t blockX,intptr_t blockY,intptr_t blockZ,int32_t smem,CUstream stream,void ** params,void ** extra)869d7be77bSChristian Sigg mgpuLaunchKernel(CUfunction function, intptr_t gridX, intptr_t gridY,
879d7be77bSChristian Sigg                  intptr_t gridZ, intptr_t blockX, intptr_t blockY,
889d7be77bSChristian Sigg                  intptr_t blockZ, int32_t smem, CUstream stream, void **params,
899d7be77bSChristian Sigg                  void **extra) {
909d7be77bSChristian Sigg   ScopedContext scopedContext;
919d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuLaunchKernel(function, gridX, gridY, gridZ, blockX,
929d7be77bSChristian Sigg                                       blockY, blockZ, smem, stream, params,
939d7be77bSChristian Sigg                                       extra));
949d7be77bSChristian Sigg }
959d7be77bSChristian Sigg 
mgpuStreamCreate()969d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUstream mgpuStreamCreate() {
979d7be77bSChristian Sigg   ScopedContext scopedContext;
989d7be77bSChristian Sigg   CUstream stream = nullptr;
999d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamCreate(&stream, CU_STREAM_NON_BLOCKING));
1009d7be77bSChristian Sigg   return stream;
1019d7be77bSChristian Sigg }
1029d7be77bSChristian Sigg 
mgpuStreamDestroy(CUstream stream)1039d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamDestroy(CUstream stream) {
1049d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamDestroy(stream));
1059d7be77bSChristian Sigg }
1069d7be77bSChristian Sigg 
1079d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuStreamSynchronize(CUstream stream)1089d7be77bSChristian Sigg mgpuStreamSynchronize(CUstream stream) {
1099d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamSynchronize(stream));
1109d7be77bSChristian Sigg }
1119d7be77bSChristian Sigg 
mgpuStreamWaitEvent(CUstream stream,CUevent event)1129d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuStreamWaitEvent(CUstream stream,
1139d7be77bSChristian Sigg                                                               CUevent event) {
1149d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuStreamWaitEvent(stream, event, /*flags=*/0));
1159d7be77bSChristian Sigg }
1169d7be77bSChristian Sigg 
mgpuEventCreate()1179d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT CUevent mgpuEventCreate() {
1189d7be77bSChristian Sigg   ScopedContext scopedContext;
1199d7be77bSChristian Sigg   CUevent event = nullptr;
1209d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventCreate(&event, CU_EVENT_DISABLE_TIMING));
1219d7be77bSChristian Sigg   return event;
1229d7be77bSChristian Sigg }
1239d7be77bSChristian Sigg 
mgpuEventDestroy(CUevent event)1249d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuEventDestroy(CUevent event) {
1259d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventDestroy(event));
1269d7be77bSChristian Sigg }
1279d7be77bSChristian Sigg 
1289d7be77bSChristian Sigg extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventSynchronize(CUevent event) {
1299d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventSynchronize(event));
1309d7be77bSChristian Sigg }
1319d7be77bSChristian Sigg 
1329d7be77bSChristian Sigg extern MLIR_CUDA_WRAPPERS_EXPORT "C" void mgpuEventRecord(CUevent event,
1339d7be77bSChristian Sigg                                                           CUstream stream) {
1349d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuEventRecord(event, stream));
1359d7be77bSChristian Sigg }
1369d7be77bSChristian Sigg 
mgpuMemAlloc(uint64_t sizeBytes,CUstream)137*6b7e6ea4SMehdi Amini extern "C" void *mgpuMemAlloc(uint64_t sizeBytes, CUstream /*stream*/) {
1389d7be77bSChristian Sigg   ScopedContext scopedContext;
1399d7be77bSChristian Sigg   CUdeviceptr ptr;
1409d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemAlloc(&ptr, sizeBytes));
1419d7be77bSChristian Sigg   return reinterpret_cast<void *>(ptr);
1429d7be77bSChristian Sigg }
1439d7be77bSChristian Sigg 
mgpuMemFree(void * ptr,CUstream)144*6b7e6ea4SMehdi Amini extern "C" void mgpuMemFree(void *ptr, CUstream /*stream*/) {
1459d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemFree(reinterpret_cast<CUdeviceptr>(ptr)));
1469d7be77bSChristian Sigg }
1479d7be77bSChristian Sigg 
mgpuMemcpy(void * dst,void * src,size_t sizeBytes,CUstream stream)148361458b1SLoren Maggiore extern "C" void mgpuMemcpy(void *dst, void *src, size_t sizeBytes,
1499d7be77bSChristian Sigg                            CUstream stream) {
1509d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemcpyAsync(reinterpret_cast<CUdeviceptr>(dst),
1519d7be77bSChristian Sigg                                      reinterpret_cast<CUdeviceptr>(src),
1529d7be77bSChristian Sigg                                      sizeBytes, stream));
1539d7be77bSChristian Sigg }
1549d7be77bSChristian Sigg 
mgpuMemset32(void * dst,unsigned int value,size_t count,CUstream stream)155361458b1SLoren Maggiore extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
156361458b1SLoren Maggiore                              CUstream stream) {
157361458b1SLoren Maggiore   CUDA_REPORT_IF_ERROR(cuMemsetD32Async(reinterpret_cast<CUdeviceptr>(dst),
158361458b1SLoren Maggiore                                         value, count, stream));
159361458b1SLoren Maggiore }
160361458b1SLoren Maggiore 
1619d7be77bSChristian Sigg /// Helper functions for writing mlir example code
1629d7be77bSChristian Sigg 
1639d7be77bSChristian Sigg // Allows to register byte array with the CUDA runtime. Helpful until we have
1649d7be77bSChristian Sigg // transfer functions implemented.
1659d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegister(void * ptr,uint64_t sizeBytes)1669d7be77bSChristian Sigg mgpuMemHostRegister(void *ptr, uint64_t sizeBytes) {
1679d7be77bSChristian Sigg   ScopedContext scopedContext;
1689d7be77bSChristian Sigg   CUDA_REPORT_IF_ERROR(cuMemHostRegister(ptr, sizeBytes, /*flags=*/0));
1699d7be77bSChristian Sigg }
1709d7be77bSChristian Sigg 
1714edc9e2aSUday Bondhugula /// Registers a memref with the CUDA runtime. `descriptor` is a pointer to a
1724edc9e2aSUday Bondhugula /// ranked memref descriptor struct of rank `rank`. Helpful until we have
1734edc9e2aSUday Bondhugula /// transfer functions implemented.
1749d7be77bSChristian Sigg extern "C" MLIR_CUDA_WRAPPERS_EXPORT void
mgpuMemHostRegisterMemRef(int64_t rank,StridedMemRefType<char,1> * descriptor,int64_t elementSizeBytes)1759d7be77bSChristian Sigg mgpuMemHostRegisterMemRef(int64_t rank, StridedMemRefType<char, 1> *descriptor,
1769d7be77bSChristian Sigg                           int64_t elementSizeBytes) {
1779d7be77bSChristian Sigg   // Only densely packed tensors are currently supported.
1784edc9e2aSUday Bondhugula   int64_t *denseStrides = (int64_t *)alloca(rank * sizeof(int64_t));
1794edc9e2aSUday Bondhugula   int64_t *sizes = descriptor->sizes;
1804edc9e2aSUday Bondhugula   for (int64_t i = rank - 1, runningStride = 1; i >= 0; i--) {
1814edc9e2aSUday Bondhugula     denseStrides[i] = runningStride;
1824edc9e2aSUday Bondhugula     runningStride *= sizes[i];
1834edc9e2aSUday Bondhugula   }
1844edc9e2aSUday Bondhugula   uint64_t sizeBytes = sizes[0] * denseStrides[0] * elementSizeBytes;
1854edc9e2aSUday Bondhugula   int64_t *strides = &sizes[rank];
186012c0cc7SNicolas Vasilache   (void)strides;
1874edc9e2aSUday Bondhugula   for (unsigned i = 0; i < rank; ++i)
1884edc9e2aSUday Bondhugula     assert(strides[i] == denseStrides[i] &&
1894edc9e2aSUday Bondhugula            "Mismatch in computed dense strides");
1909d7be77bSChristian Sigg 
1914edc9e2aSUday Bondhugula   auto *ptr = descriptor->data + descriptor->offset * elementSizeBytes;
1929d7be77bSChristian Sigg   mgpuMemHostRegister(ptr, sizeBytes);
1939d7be77bSChristian Sigg }
19484718d37SKrzysztof Drewniak 
mgpuSetDefaultDevice(int32_t device)19584718d37SKrzysztof Drewniak extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuSetDefaultDevice(int32_t device) {
19684718d37SKrzysztof Drewniak   defaultDevice = device;
19784718d37SKrzysztof Drewniak }
198