1 //===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This is a command line utility that executes an MLIR file on Vulkan by
// translating the MLIR GPU module to SPIR-V and the host part to LLVM IR
// before JIT-compiling and executing the latter.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
17 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
20 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
21 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
25 #include "mlir/Dialect/GPU/Transforms/Passes.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
30 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
31 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
32 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
33 #include "mlir/ExecutionEngine/JitRunner.h"
34 #include "mlir/ExecutionEngine/OptUtils.h"
35 #include "mlir/Pass/Pass.h"
36 #include "mlir/Pass/PassManager.h"
37 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
38 #include "mlir/Target/LLVMIR/Export.h"
39 #include "llvm/Support/InitLLVM.h"
40 #include "llvm/Support/TargetSelect.h"
41 
42 using namespace mlir;
43 
runMLIRPasses(ModuleOp module)44 static LogicalResult runMLIRPasses(ModuleOp module) {
45   PassManager passManager(module.getContext());
46   applyPassManagerCLOptions(passManager);
47 
48   passManager.addPass(createGpuKernelOutliningPass());
49   passManager.addPass(memref::createFoldSubViewOpsPass());
50   passManager.addPass(createConvertGPUToSPIRVPass());
51   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
52   modulePM.addPass(spirv::createLowerABIAttributesPass());
53   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
54   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
55   LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
56   passManager.addPass(createMemRefToLLVMPass());
57   passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
58   passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
59   passManager.addPass(createReconcileUnrealizedCastsPass());
60   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
61   return passManager.run(module);
62 }
63 
main(int argc,char ** argv)64 int main(int argc, char **argv) {
65   llvm::llvm_shutdown_obj x;
66   registerPassManagerCLOptions();
67 
68   llvm::InitLLVM y(argc, argv);
69   llvm::InitializeNativeTarget();
70   llvm::InitializeNativeTargetAsmPrinter();
71 
72   mlir::JitRunnerConfig jitRunnerConfig;
73   jitRunnerConfig.mlirTransformer = runMLIRPasses;
74 
75   mlir::DialectRegistry registry;
76   registry.insert<mlir::arith::ArithmeticDialect, mlir::LLVM::LLVMDialect,
77                   mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
78                   mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
79   mlir::registerLLVMDialectTranslation(registry);
80 
81   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
82 }
83