//===- LowerGPUToCUBIN.cpp - Convert GPU kernel to CUBIN blob -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a pass that serializes a gpu module into a CUBIN blob
// and adds that blob as a string attribute of the module.
//
//===----------------------------------------------------------------------===//
13d7ef488bSMogball
14d7ef488bSMogball #include "mlir/Dialect/GPU/Transforms/Passes.h"
152224221fSChristian Sigg
162224221fSChristian Sigg #if MLIR_GPU_TO_CUBIN_PASS_ENABLE
172224221fSChristian Sigg #include "mlir/Pass/Pass.h"
182224221fSChristian Sigg #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
192224221fSChristian Sigg #include "mlir/Target/LLVMIR/Export.h"
202224221fSChristian Sigg #include "llvm/Support/TargetSelect.h"
212224221fSChristian Sigg
222224221fSChristian Sigg #include <cuda.h>
232224221fSChristian Sigg
242224221fSChristian Sigg using namespace mlir;
252224221fSChristian Sigg
emitCudaError(const llvm::Twine & expr,const char * buffer,CUresult result,Location loc)262224221fSChristian Sigg static void emitCudaError(const llvm::Twine &expr, const char *buffer,
272224221fSChristian Sigg CUresult result, Location loc) {
282224221fSChristian Sigg const char *error;
292224221fSChristian Sigg cuGetErrorString(result, &error);
302224221fSChristian Sigg emitError(loc, expr.concat(" failed with error code ")
312224221fSChristian Sigg .concat(llvm::Twine{error})
322224221fSChristian Sigg .concat("[")
332224221fSChristian Sigg .concat(buffer)
342224221fSChristian Sigg .concat("]"));
352224221fSChristian Sigg }
362224221fSChristian Sigg
// Evaluates the CUDA driver call `expr` once. If it does not succeed, reports
// the failure through emitCudaError and returns a value-initialized result
// from the enclosing function. The expansion site must have `jitErrorBuffer`
// and `loc` in scope.
#define RETURN_ON_CUDA_ERROR(expr)                                             \
  do {                                                                         \
    const CUresult cudaStatus = (expr);                                        \
    if (cudaStatus != CUDA_SUCCESS) {                                          \
      emitCudaError(#expr, jitErrorBuffer, cudaStatus, loc);                   \
      return {};                                                               \
    }                                                                          \
  } while (false)
442224221fSChristian Sigg
452224221fSChristian Sigg namespace {
462224221fSChristian Sigg class SerializeToCubinPass
472224221fSChristian Sigg : public PassWrapper<SerializeToCubinPass, gpu::SerializeToBlobPass> {
482224221fSChristian Sigg public:
491269f96dSRiver Riddle MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(SerializeToCubinPass)
501269f96dSRiver Riddle
512224221fSChristian Sigg SerializeToCubinPass();
522224221fSChristian Sigg
getArgument() const53b5e22e6dSMehdi Amini StringRef getArgument() const override { return "gpu-to-cubin"; }
getDescription() const54b5e22e6dSMehdi Amini StringRef getDescription() const override {
55b5e22e6dSMehdi Amini return "Lower GPU kernel function to CUBIN binary annotations";
56b5e22e6dSMehdi Amini }
57b5e22e6dSMehdi Amini
582224221fSChristian Sigg private:
592224221fSChristian Sigg void getDependentDialects(DialectRegistry ®istry) const override;
602224221fSChristian Sigg
612224221fSChristian Sigg // Serializes PTX to CUBIN.
622224221fSChristian Sigg std::unique_ptr<std::vector<char>>
632224221fSChristian Sigg serializeISA(const std::string &isa) override;
642224221fSChristian Sigg };
652224221fSChristian Sigg } // namespace
662224221fSChristian Sigg
672224221fSChristian Sigg // Sets the 'option' to 'value' unless it already has a value.
maybeSetOption(Pass::Option<std::string> & option,const char * value)682224221fSChristian Sigg static void maybeSetOption(Pass::Option<std::string> &option,
692224221fSChristian Sigg const char *value) {
702224221fSChristian Sigg if (!option.hasValue())
712224221fSChristian Sigg option = value;
722224221fSChristian Sigg }
732224221fSChristian Sigg
SerializeToCubinPass()742224221fSChristian Sigg SerializeToCubinPass::SerializeToCubinPass() {
752224221fSChristian Sigg maybeSetOption(this->triple, "nvptx64-nvidia-cuda");
762224221fSChristian Sigg maybeSetOption(this->chip, "sm_35");
772224221fSChristian Sigg maybeSetOption(this->features, "+ptx60");
782224221fSChristian Sigg }
792224221fSChristian Sigg
getDependentDialects(DialectRegistry & registry) const802224221fSChristian Sigg void SerializeToCubinPass::getDependentDialects(
812224221fSChristian Sigg DialectRegistry ®istry) const {
822224221fSChristian Sigg registerNVVMDialectTranslation(registry);
832224221fSChristian Sigg gpu::SerializeToBlobPass::getDependentDialects(registry);
842224221fSChristian Sigg }
852224221fSChristian Sigg
862224221fSChristian Sigg std::unique_ptr<std::vector<char>>
serializeISA(const std::string & isa)872224221fSChristian Sigg SerializeToCubinPass::serializeISA(const std::string &isa) {
882224221fSChristian Sigg Location loc = getOperation().getLoc();
892224221fSChristian Sigg char jitErrorBuffer[4096] = {0};
902224221fSChristian Sigg
912224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuInit(0));
922224221fSChristian Sigg
932224221fSChristian Sigg // Linking requires a device context.
942224221fSChristian Sigg CUdevice device;
952224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuDeviceGet(&device, 0));
962224221fSChristian Sigg CUcontext context;
972224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuCtxCreate(&context, 0, device));
982224221fSChristian Sigg CUlinkState linkState;
992224221fSChristian Sigg
1002224221fSChristian Sigg CUjit_option jitOptions[] = {CU_JIT_ERROR_LOG_BUFFER,
1012224221fSChristian Sigg CU_JIT_ERROR_LOG_BUFFER_SIZE_BYTES};
1022224221fSChristian Sigg void *jitOptionsVals[] = {jitErrorBuffer,
1032224221fSChristian Sigg reinterpret_cast<void *>(sizeof(jitErrorBuffer))};
1042224221fSChristian Sigg
1052224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuLinkCreate(2, /* number of jit options */
1062224221fSChristian Sigg jitOptions, /* jit options */
1072224221fSChristian Sigg jitOptionsVals, /* jit option values */
1082224221fSChristian Sigg &linkState));
1092224221fSChristian Sigg
1102224221fSChristian Sigg auto kernelName = getOperation().getName().str();
1112224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuLinkAddData(
1122224221fSChristian Sigg linkState, CUjitInputType::CU_JIT_INPUT_PTX,
1132224221fSChristian Sigg const_cast<void *>(static_cast<const void *>(isa.c_str())), isa.length(),
1142224221fSChristian Sigg kernelName.c_str(), 0, /* number of jit options */
1152224221fSChristian Sigg nullptr, /* jit options */
1162224221fSChristian Sigg nullptr /* jit option values */
1172224221fSChristian Sigg ));
1182224221fSChristian Sigg
1192224221fSChristian Sigg void *cubinData;
1202224221fSChristian Sigg size_t cubinSize;
1212224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuLinkComplete(linkState, &cubinData, &cubinSize));
1222224221fSChristian Sigg
1232224221fSChristian Sigg char *cubinAsChar = static_cast<char *>(cubinData);
1242224221fSChristian Sigg auto result =
1252224221fSChristian Sigg std::make_unique<std::vector<char>>(cubinAsChar, cubinAsChar + cubinSize);
1262224221fSChristian Sigg
1272224221fSChristian Sigg // This will also destroy the cubin data.
1282224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuLinkDestroy(linkState));
1292224221fSChristian Sigg RETURN_ON_CUDA_ERROR(cuCtxDestroy(context));
1302224221fSChristian Sigg
1312224221fSChristian Sigg return result;
1322224221fSChristian Sigg }
1332224221fSChristian Sigg
1342224221fSChristian Sigg // Register pass to serialize GPU kernel functions to a CUBIN binary annotation.
registerGpuSerializeToCubinPass()1352224221fSChristian Sigg void mlir::registerGpuSerializeToCubinPass() {
136*b7f93c28SJeff Niu PassRegistration<SerializeToCubinPass> registerSerializeToCubin([] {
1372224221fSChristian Sigg // Initialize LLVM NVPTX backend.
1382224221fSChristian Sigg LLVMInitializeNVPTXTarget();
1392224221fSChristian Sigg LLVMInitializeNVPTXTargetInfo();
1402224221fSChristian Sigg LLVMInitializeNVPTXTargetMC();
1412224221fSChristian Sigg LLVMInitializeNVPTXAsmPrinter();
1422224221fSChristian Sigg
1432224221fSChristian Sigg return std::make_unique<SerializeToCubinPass>();
1442224221fSChristian Sigg });
1452224221fSChristian Sigg }
1462224221fSChristian Sigg #else // MLIR_GPU_TO_CUBIN_PASS_ENABLE
registerGpuSerializeToCubinPass()1472224221fSChristian Sigg void mlir::registerGpuSerializeToCubinPass() {}
1482224221fSChristian Sigg #endif // MLIR_GPU_TO_CUBIN_PASS_ENABLE
149