1 //===- OptUtils.cpp - MLIR Execution Engine optimization pass utilities ---===// 2 // 3 // Copyright 2019 The MLIR Authors. 4 // 5 // Licensed under the Apache License, Version 2.0 (the "License"); 6 // you may not use this file except in compliance with the License. 7 // You may obtain a copy of the License at 8 // 9 // http://www.apache.org/licenses/LICENSE-2.0 10 // 11 // Unless required by applicable law or agreed to in writing, software 12 // distributed under the License is distributed on an "AS IS" BASIS, 13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 // See the License for the specific language governing permissions and 15 // limitations under the License. 16 // ============================================================================= 17 // 18 // This file implements the utility functions to trigger LLVM optimizations from 19 // MLIR Execution Engine. 20 // 21 //===----------------------------------------------------------------------===// 22 23 #include "mlir/ExecutionEngine/OptUtils.h" 24 25 #include "llvm/IR/LegacyPassManager.h" 26 #include "llvm/IR/LegacyPassNameParser.h" 27 #include "llvm/IR/Module.h" 28 #include "llvm/InitializePasses.h" 29 #include "llvm/Pass.h" 30 #include "llvm/Support/Allocator.h" 31 #include "llvm/Support/CommandLine.h" 32 #include "llvm/Support/Error.h" 33 #include "llvm/Support/StringSaver.h" 34 #include "llvm/Transforms/IPO.h" 35 #include "llvm/Transforms/IPO/PassManagerBuilder.h" 36 #include <mutex> 37 38 // Run the module and function passes managed by the module manager. 39 static void runPasses(llvm::legacy::PassManager &modulePM, 40 llvm::legacy::FunctionPassManager &funcPM, 41 llvm::Module &m) { 42 for (auto &func : m) { 43 funcPM.run(func); 44 } 45 modulePM.run(m); 46 } 47 48 // Initialize basic LLVM transformation passes under lock. 49 void mlir::initializeLLVMPasses() { 50 static std::mutex mutex; 51 std::lock_guard<std::mutex> lock(mutex); 52 53 auto ®istry = *llvm::PassRegistry::getPassRegistry(); 54 llvm::initializeCore(registry); 55 llvm::initializeTransformUtils(registry); 56 llvm::initializeScalarOpts(registry); 57 llvm::initializeIPO(registry); 58 llvm::initializeInstCombine(registry); 59 llvm::initializeAggressiveInstCombine(registry); 60 llvm::initializeAnalysis(registry); 61 llvm::initializeVectorization(registry); 62 } 63 64 // Create and return a lambda that uses LLVM pass manager builder to set up 65 // optimizations based on the given level. 66 std::function<llvm::Error(llvm::Module *)> 67 mlir::makeOptimizingTransformer(unsigned optLevel, unsigned sizeLevel) { 68 return [optLevel, sizeLevel](llvm::Module *m) -> llvm::Error { 69 llvm::PassManagerBuilder builder; 70 builder.OptLevel = optLevel; 71 builder.SizeLevel = sizeLevel; 72 builder.Inliner = llvm::createFunctionInliningPass( 73 optLevel, sizeLevel, /*DisableInlineHotCallSite=*/false); 74 75 llvm::legacy::PassManager modulePM; 76 llvm::legacy::FunctionPassManager funcPM(m); 77 builder.populateModulePassManager(modulePM); 78 builder.populateFunctionPassManager(funcPM); 79 runPasses(modulePM, funcPM, *m); 80 81 return llvm::Error::success(); 82 }; 83 } 84 85 // Create and return a lambda that leverages LLVM PassInfo command line parser 86 // to construct passes given the command line flags that come from the given 87 // string rather than from the command line. 88 std::function<llvm::Error(llvm::Module *)> 89 mlir::makeLLVMPassesTransformer(std::string config) { 90 return [config](llvm::Module *m) -> llvm::Error { 91 static llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> 92 llvmPasses(llvm::cl::desc("LLVM optimizing passes to run")); 93 llvm::BumpPtrAllocator allocator; 94 llvm::StringSaver saver(allocator); 95 llvm::SmallVector<const char *, 16> args; 96 args.push_back(""); // inject dummy program name 97 llvm::cl::TokenizeGNUCommandLine(config, saver, args); 98 llvm::cl::ParseCommandLineOptions(args.size(), args.data()); 99 100 llvm::legacy::PassManager modulePM; 101 102 for (const auto *passInfo : llvmPasses) { 103 if (!passInfo->getNormalCtor()) 104 continue; 105 106 auto *pass = passInfo->createPass(); 107 if (!pass) 108 return llvm::make_error<llvm::StringError>( 109 "could not create pass " + passInfo->getPassName(), 110 llvm::inconvertibleErrorCode()); 111 112 modulePM.add(pass); 113 } 114 115 modulePM.run(*m); 116 return llvm::Error::success(); 117 }; 118 } 119