1 //===- TestConvertGPUKernelToCubin.cpp - Test gpu kernel cubin lowering ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/GPU/Passes.h"
10 
11 #include "mlir/Pass/Pass.h"
12 #include "mlir/Target/LLVMIR/Dialect/NVVM/NVVMToLLVMIRTranslation.h"
13 #include "mlir/Target/LLVMIR/Export.h"
14 #include "llvm/Support/TargetSelect.h"
15 
16 using namespace mlir;
17 
18 #if MLIR_CUDA_CONVERSIONS_ENABLED
19 namespace {
20 class TestSerializeToCubinPass
21     : public PassWrapper<TestSerializeToCubinPass, gpu::SerializeToBlobPass> {
22 public:
23   TestSerializeToCubinPass();
24 
25 private:
26   void getDependentDialects(DialectRegistry &registry) const override;
27 
28   // Serializes PTX to CUBIN.
29   std::unique_ptr<std::vector<char>>
30   serializeISA(const std::string &isa) override;
31 };
32 } // namespace
33 
34 TestSerializeToCubinPass::TestSerializeToCubinPass() {
35   this->triple = "nvptx64-nvidia-cuda";
36   this->chip = "sm_35";
37   this->features = "+ptx60";
38 }
39 
40 void TestSerializeToCubinPass::getDependentDialects(
41     DialectRegistry &registry) const {
42   registerNVVMDialectTranslation(registry);
43   gpu::SerializeToBlobPass::getDependentDialects(registry);
44 }
45 
46 std::unique_ptr<std::vector<char>>
47 TestSerializeToCubinPass::serializeISA(const std::string &) {
48   std::string data = "CUBIN";
49   return std::make_unique<std::vector<char>>(data.begin(), data.end());
50 }
51 
52 namespace mlir {
53 namespace test {
54 // Register test pass to serialize GPU module to a CUBIN binary annotation.
55 void registerTestGpuSerializeToCubinPass() {
56   PassRegistration<TestSerializeToCubinPass> registerSerializeToCubin(
57       "test-gpu-to-cubin",
58       "Lower GPU kernel function to CUBIN binary annotations", [] {
59         // Initialize LLVM NVPTX backend.
60         LLVMInitializeNVPTXTarget();
61         LLVMInitializeNVPTXTargetInfo();
62         LLVMInitializeNVPTXTargetMC();
63         LLVMInitializeNVPTXAsmPrinter();
64 
65         return std::make_unique<TestSerializeToCubinPass>();
66       });
67 }
68 } // namespace test
69 } // namespace mlir
70 #endif // MLIR_CUDA_CONVERSIONS_ENABLED
71