1 //===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This is a command line utility that executes an MLIR file on Vulkan by
// translating the MLIR GPU module to SPIR-V and the host part to LLVM IR
// before JIT-compiling and executing the latter.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
16 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
17 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
18 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
19 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
20 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
21 #include "mlir/Conversion/StandardToSPIRV/StandardToSPIRVPass.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
23 #include "mlir/Dialect/GPU/GPUDialect.h"
24 #include "mlir/Dialect/GPU/Passes.h"
25 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
28 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
29 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
30 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
31 #include "mlir/Dialect/StandardOps/IR/Ops.h"
32 #include "mlir/ExecutionEngine/JitRunner.h"
33 #include "mlir/ExecutionEngine/OptUtils.h"
34 #include "mlir/Pass/Pass.h"
35 #include "mlir/Pass/PassManager.h"
36 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
37 #include "mlir/Target/LLVMIR/Export.h"
38 #include "llvm/Support/InitLLVM.h"
39 #include "llvm/Support/TargetSelect.h"
40 
41 using namespace mlir;
42 
43 static LogicalResult runMLIRPasses(ModuleOp module) {
44   PassManager passManager(module.getContext());
45   applyPassManagerCLOptions(passManager);
46 
47   passManager.addPass(createGpuKernelOutliningPass());
48   passManager.addPass(memref::createFoldSubViewOpsPass());
49   passManager.addPass(createConvertGPUToSPIRVPass());
50   OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
51   modulePM.addPass(spirv::createLowerABIAttributesPass());
52   modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
53   passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
54   LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
55   llvmOptions.emitCWrappers = true;
56   passManager.addPass(createMemRefToLLVMPass());
57   passManager.addPass(createLowerToLLVMPass(llvmOptions));
58   passManager.addPass(createReconcileUnrealizedCastsPass());
59   passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
60   return passManager.run(module);
61 }
62 
63 int main(int argc, char **argv) {
64   llvm::llvm_shutdown_obj x;
65   registerPassManagerCLOptions();
66 
67   llvm::InitLLVM y(argc, argv);
68   llvm::InitializeNativeTarget();
69   llvm::InitializeNativeTargetAsmPrinter();
70   mlir::initializeLLVMPasses();
71 
72   mlir::JitRunnerConfig jitRunnerConfig;
73   jitRunnerConfig.mlirTransformer = runMLIRPasses;
74 
75   mlir::DialectRegistry registry;
76   registry.insert<mlir::arith::ArithmeticDialect, mlir::LLVM::LLVMDialect,
77                   mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
78                   mlir::StandardOpsDialect, mlir::memref::MemRefDialect>();
79   mlir::registerLLVMDialectTranslation(registry);
80 
81   return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
82 }
83