1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
// This file implements the utility functions used to trigger LLVM optimizations
// from the MLIR Execution Engine.
20 //
21 //===----------------------------------------------------------------------===//
22 
#include "mlir/ExecutionEngine/OptUtils.h"

#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"

#include <mutex>
#include <string>
37 
38 // Run the module and function passes managed by the module manager.
39 static void runPasses(llvm::legacy::PassManager &modulePM,
40                       llvm::legacy::FunctionPassManager &funcPM,
41                       llvm::Module &m) {
42   for (auto &func : m) {
43     funcPM.run(func);
44   }
45   modulePM.run(m);
46 }
47 
48 // Initialize basic LLVM transformation passes under lock.
49 void mlir::initializeLLVMPasses() {
50   static std::mutex mutex;
51   std::lock_guard<std::mutex> lock(mutex);
52 
53   auto &registry = *llvm::PassRegistry::getPassRegistry();
54   llvm::initializeCore(registry);
55   llvm::initializeTransformUtils(registry);
56   llvm::initializeScalarOpts(registry);
57   llvm::initializeIPO(registry);
58   llvm::initializeInstCombine(registry);
59   llvm::initializeAggressiveInstCombine(registry);
60   llvm::initializeAnalysis(registry);
61   llvm::initializeVectorization(registry);
62 }
63 
64 // Create and return a lambda that uses LLVM pass manager builder to set up
65 // optimizations based on the given level.
66 std::function<llvm::Error(llvm::Module *)>
67 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) {
68   return [optLevel, sizeLevel](llvm::Module *m) -> llvm::Error {
69     llvm::PassManagerBuilder builder;
70     builder.OptLevel = optLevel;
71     builder.SizeLevel = sizeLevel;
72     builder.Inliner = llvm::createFunctionInliningPass(
73         optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
74 
75     llvm::legacy::PassManager modulePM;
76     llvm::legacy::FunctionPassManager funcPM(m);
77     builder.populateModulePassManager(modulePM);
78     builder.populateFunctionPassManager(funcPM);
79     runPasses(modulePM, funcPM, *m);
80 
81     return llvm::Error::success();
82   };
83 }
84 
85 // Create and return a lambda that leverages LLVM PassInfo command line parser
86 // to construct passes given the command line flags that come from the given
87 // string rather than from the command line.
88 std::function<llvm::Error(llvm::Module *)>
89 mlir::makeLLVMPassesTransformer(std::string config) {
90   return [config](llvm::Module *m) -> llvm::Error {
91     static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser>
92         llvmPasses(llvm::cl::desc("LLVM optimizing passes to run"));
93     llvm::BumpPtrAllocator allocator;
94     llvm::StringSaver saver(allocator);
95     llvm::SmallVector<const char *, 16> args;
96     args.push_back(""); // inject dummy program name
97     llvm::cl::TokenizeGNUCommandLine(config, saver, args);
98     llvm::cl::ParseCommandLineOptions(args.size(), args.data());
99 
100     llvm::legacy::PassManager modulePM;
101 
102     for (const auto *passInfo : llvmPasses) {
103       if (!passInfo->getNormalCtor())
104         continue;
105 
106       auto *pass = passInfo->createPass();
107       if (!pass)
108         return llvm::make_error<llvm::StringError>(
109             "could not create pass " + passInfo->getPassName(),
110             llvm::inconvertibleErrorCode());
111 
112       modulePM.add(pass);
113     }
114 
115     modulePM.run(*m);
116     return llvm::Error::success();
117   };
118 }
119