1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file implements the utility functions to trigger LLVM optimizations from
19 // MLIR Execution Engine.
20 //
21 //===----------------------------------------------------------------------===//
22 
#include "mlir/ExecutionEngine/OptUtils.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Allocator.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/IPO.h"
#include "llvm/Transforms/IPO/AlwaysInliner.h"
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include <climits>
#include <mutex>
41 
42 // Run the module and function passes managed by the module manager.
43 static void runPasses(llvm::legacy::PassManager &modulePM,
44                       llvm::legacy::FunctionPassManager &funcPM,
45                       llvm::Module &m) {
46   funcPM.doInitialization();
47   for (auto &func : m) {
48     funcPM.run(func);
49   }
50   funcPM.doFinalization();
51   modulePM.run(m);
52 }
53 
54 // Initialize basic LLVM transformation passes under lock.
55 void mlir::initializeLLVMPasses() {
56   static std::mutex mutex;
57   std::lock_guard<std::mutex> lock(mutex);
58 
59   auto &registry = *llvm::PassRegistry::getPassRegistry();
60   llvm::initializeCore(registry);
61   llvm::initializeTransformUtils(registry);
62   llvm::initializeScalarOpts(registry);
63   llvm::initializeIPO(registry);
64   llvm::initializeInstCombine(registry);
65   llvm::initializeAggressiveInstCombine(registry);
66   llvm::initializeAnalysis(registry);
67   llvm::initializeVectorization(registry);
68 }
69 
70 // Populate pass managers according to the optimization and size levels.
71 // This behaves similarly to LLVM opt.
72 static void populatePassManagers(llvm::legacy::PassManager &modulePM,
73                                  llvm::legacy::FunctionPassManager &funcPM,
74                                  unsigned optLevel, unsigned sizeLevel,
75                                  llvm::TargetMachine *targetMachine) {
76   llvm::PassManagerBuilder builder;
77   builder.OptLevel = optLevel;
78   builder.SizeLevel = sizeLevel;
79   builder.Inliner = llvm::createFunctionInliningPass(
80       optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false);
81   builder.LoopVectorize = optLevel > 1 && sizeLevel < 2;
82   builder.SLPVectorize = optLevel > 1 && sizeLevel < 2;
83   builder.DisableUnrollLoops = (optLevel == 0);
84 
85   if (targetMachine) {
86     // Add pass to initialize TTI for this specific target. Otherwise, TTI will
87     // be initialized to NoTTIImpl by defaul.
88     modulePM.add(createTargetTransformInfoWrapperPass(
89         targetMachine->getTargetIRAnalysis()));
90     funcPM.add(createTargetTransformInfoWrapperPass(
91         targetMachine->getTargetIRAnalysis()));
92   }
93 
94   builder.populateModulePassManager(modulePM);
95   builder.populateFunctionPassManager(funcPM);
96 }
97 
98 // Create and return a lambda that uses LLVM pass manager builder to set up
99 // optimizations based on the given level.
100 std::function<llvm::Error(llvm::Module *)>
101 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel,
102                                 llvm::TargetMachine *targetMachine) {
103   return [optLevel, sizeLevel, targetMachine](llvm::Module *m) -> llvm::Error {
104     llvm::legacy::PassManager modulePM;
105     llvm::legacy::FunctionPassManager funcPM(m);
106     populatePassManagers(modulePM, funcPM, optLevel, sizeLevel, targetMachine);
107     runPasses(modulePM, funcPM, *m);
108 
109     return llvm::Error::success();
110   };
111 }
112 
113 // Create and return a lambda that is given a set of passes to run, plus an
114 // optional optimization level to pre-populate the pass manager.
115 std::function<llvm::Error(llvm::Module *)> mlir::makeLLVMPassesTransformer(
116     llvm::ArrayRef<const llvm::PassInfo *> llvmPasses,
117     llvm::Optional<unsigned> mbOptLevel, llvm::TargetMachine *targetMachine,
118     unsigned optPassesInsertPos) {
119   return [llvmPasses, mbOptLevel, optPassesInsertPos,
120           targetMachine](llvm::Module *m) -> llvm::Error {
121     llvm::legacy::PassManager modulePM;
122     llvm::legacy::FunctionPassManager funcPM(m);
123 
124     bool insertOptPasses = mbOptLevel.hasValue();
125     for (unsigned i = 0, e = llvmPasses.size(); i < e; ++i) {
126       const auto *passInfo = llvmPasses[i];
127       if (!passInfo->getNormalCtor())
128         continue;
129 
130       if (insertOptPasses && optPassesInsertPos == i) {
131         populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
132                              targetMachine);
133         insertOptPasses = false;
134       }
135 
136       auto *pass = passInfo->createPass();
137       if (!pass)
138         return llvm::make_error<llvm::StringError>(
139             "could not create pass " + passInfo->getPassName(),
140             llvm::inconvertibleErrorCode());
141       modulePM.add(pass);
142     }
143 
144     if (insertOptPasses)
145       populatePassManagers(modulePM, funcPM, mbOptLevel.getValue(), 0,
146                            targetMachine);
147 
148     runPasses(modulePM, funcPM, *m);
149     return llvm::Error::success();
150   };
151 }
152