1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the utility functions to trigger LLVM optimizations from
19 // MLIR Execution Engine.
20 //
21 //===----------------------------------------------------------------------===//
22 
#include "mlir/ExecutionEngine/OptUtils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include <climits>
#include <mutex>
#include <vector>
39 
40 // Run the module and function passes managed by the module manager.
41 static void runPasses(llvm::legacy::PassManager &modulePM,
42                       llvm::legacy::FunctionPassManager &funcPM,
43                       llvm::Module &m) {
44   funcPM.doInitialization();
45   for (auto &func : m) {
46     funcPM.run(func);
47   }
48   funcPM.doFinalization();
49   modulePM.run(m);
50 }
51 
52 // Initialize basic LLVM transformation passes under lock.
53 void mlir::initializeLLVMPasses() {
54   static std::mutex mutex;
55   std::lock_guard<std::mutex> lock(mutex);
56 
57   auto &registry = *llvm::PassRegistry::getPassRegistry();
58   llvm::initializeCore(registry);
59   llvm::initializeTransformUtils(registry);
60   llvm::initializeScalarOpts(registry);
61   llvm::initializeIPO(registry);
62   llvm::initializeInstCombine(registry);
63   llvm::initializeAggressiveInstCombine(registry);
64   llvm::initializeAnalysis(registry);
65   llvm::initializeVectorization(registry);
66 }
67 
68 // Populate pass managers according to the optimization and size levels.
69 // This behaves similarly to LLVM opt.
70 static void populatePassManagers(llvm::legacy::PassManager &modulePM,
71                                  llvm::legacy::FunctionPassManager &funcPM,
72                                  unsigned optLevel, unsigned sizeLevel) {
73   llvm::PassManagerBuilder builder;
74   builder.OptLevel = optLevel;
75   builder.SizeLevel = sizeLevel;
76   builder.Inliner = llvm::createFunctionInliningPass(
77       optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
78   builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
79   builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
80   builder.DisableUnrollLoops = (optLevel == 0);
81 
82   builder.populateModulePassManager(modulePM);
83   builder.populateFunctionPassManager(funcPM);
84 }
85 
86 // Create and return a lambda that uses LLVM pass manager builder to set up
87 // optimizations based on the given level.
88 std::function<llvm::Error(llvm::Module *)>
89 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) {
90   return [optLevel, sizeLevel](llvm::Module *m) -> llvm::Error {
91 
92     llvm::legacy::PassManager modulePM;
93     llvm::legacy::FunctionPassManager funcPM(m);
94     populatePassManagers(modulePM, funcPM, optLevel, sizeLevel);
95     runPasses(modulePM, funcPM, *m);
96 
97     return llvm::Error::success();
98   };
99 }
100 
101 // Create and return a lambda that is given a set of passes to run, plus an
102 // optional optimization level to pre-populate the pass manager.
103 std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
104     llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
105     llvm::Optional<unsigned> mbOptLevel, unsigned optPassesInsertPos) {
106   return [llvmPasses, mbOptLevel,
107           optPassesInsertPos](llvm::Module *m) -> llvm::Error {
108     llvm::legacy::PassManager modulePM;
109     llvm::legacy::FunctionPassManager funcPM(m);
110 
111     bool insertOptPasses = mbOptLevel.hasValue();
112     for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
113       const auto *passInfo = llvmPasses[i];
114       if (!passInfo->getNormalCtor())
115         continue;
116 
117       if (insertOptPasses && optPassesInsertPos == i) {
118         populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0);
119         insertOptPasses = false;
120       }
121 
122       auto *pass = passInfo->createPass();
123       if (!pass)
124         return llvm::make_error<llvm::StringError>(
125             "could not create pass " + passInfo->getPassName(),
126             llvm::inconvertibleErrorCode());
127       modulePM.add(pass);
128     }
129 
130     if (insertOptPasses)
131       populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0);
132 
133     runPasses(modulePM, funcPM, *m);
134     return llvm::Error::success();
135   };
136 }
137