1 //===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a command line utility that executes an MLIR file on the Vulkan by
10 // translating MLIR GPU module to SPIR-V and host part to LLVM IR before
11 // JIT-compiling and executing the latter.
12 //
13 //===----------------------------------------------------------------------===//
14
15 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
16 #include "mlir/Conversion/FuncToSPIRV/FuncToSPIRVPass.h"
17 #include "mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
18 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
20 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
21 #include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"
22 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"
24 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
25 #include "mlir/Dialect/GPU/Transforms/Passes.h"
26 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
27 #include "mlir/Dialect/LLVMIR/Transforms/RequestCWrappers.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
30 #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
31 #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
32 #include "mlir/Dialect/SPIRV/Transforms/Passes.h"
33 #include "mlir/ExecutionEngine/JitRunner.h"
34 #include "mlir/ExecutionEngine/OptUtils.h"
35 #include "mlir/Pass/Pass.h"
36 #include "mlir/Pass/PassManager.h"
37 #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h"
38 #include "mlir/Target/LLVMIR/Export.h"
39 #include "llvm/Support/InitLLVM.h"
40 #include "llvm/Support/TargetSelect.h"
41
42 using namespace mlir;
43
runMLIRPasses(ModuleOp module)44 static LogicalResult runMLIRPasses(ModuleOp module) {
45 PassManager passManager(module.getContext());
46 applyPassManagerCLOptions(passManager);
47
48 passManager.addPass(createGpuKernelOutliningPass());
49 passManager.addPass(memref::createFoldSubViewOpsPass());
50 passManager.addPass(createConvertGPUToSPIRVPass());
51 OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
52 modulePM.addPass(spirv::createLowerABIAttributesPass());
53 modulePM.addPass(spirv::createUpdateVersionCapabilityExtensionPass());
54 passManager.addPass(createConvertGpuLaunchFuncToVulkanLaunchFuncPass());
55 LowerToLLVMOptions llvmOptions(module.getContext(), DataLayout(module));
56 passManager.addPass(createMemRefToLLVMPass());
57 passManager.nest<func::FuncOp>().addPass(LLVM::createRequestCWrappersPass());
58 passManager.addPass(createConvertFuncToLLVMPass(llvmOptions));
59 passManager.addPass(createReconcileUnrealizedCastsPass());
60 passManager.addPass(createConvertVulkanLaunchFuncToVulkanCallsPass());
61 return passManager.run(module);
62 }
63
main(int argc,char ** argv)64 int main(int argc, char **argv) {
65 llvm::llvm_shutdown_obj x;
66 registerPassManagerCLOptions();
67
68 llvm::InitLLVM y(argc, argv);
69 llvm::InitializeNativeTarget();
70 llvm::InitializeNativeTargetAsmPrinter();
71
72 mlir::JitRunnerConfig jitRunnerConfig;
73 jitRunnerConfig.mlirTransformer = runMLIRPasses;
74
75 mlir::DialectRegistry registry;
76 registry.insert<mlir::arith::ArithmeticDialect, mlir::LLVM::LLVMDialect,
77 mlir::gpu::GPUDialect, mlir::spirv::SPIRVDialect,
78 mlir::func::FuncDialect, mlir::memref::MemRefDialect>();
79 mlir::registerLLVMDialectTranslation(registry);
80
81 return mlir::JitRunnerMain(argc, argv, registry, jitRunnerConfig);
82 }
83