1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/IR/StandardTypes.h" 25 #include "mlir/Parser.h" 26 #include "mlir/Support/FileUtilities.h" 27 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/LLVMContext.h" 32 #include "llvm/IR/LegacyPassNameParser.h" 33 #include "llvm/IR/Module.h" 34 #include "llvm/Support/CommandLine.h" 35 #include "llvm/Support/FileUtilities.h" 36 #include "llvm/Support/SourceMgr.h" 37 #include "llvm/Support/StringSaver.h" 38 #include "llvm/Support/ToolOutputFile.h" 39 #include <numeric> 40 41 using namespace mlir; 42 using llvm::Error; 43 44 namespace { 45 /// This options struct prevents the need for global static initializers, and 46 /// is only initialized if the JITRunner is invoked. 47 struct Options { 48 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 49 llvm::cl::desc("<input file>"), 50 llvm::cl::init("-")}; 51 llvm::cl::opt<std::string> mainFuncName{ 52 "e", llvm::cl::desc("The function to be called"), 53 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 54 llvm::cl::opt<std::string> mainFuncType{ 55 "entry-point-result", 56 llvm::cl::desc("Textual description of the function type to be called"), 57 llvm::cl::value_desc("f32 | void"), llvm::cl::init("f32")}; 58 59 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 60 61 // CLI list of pass information 62 llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{ 63 llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)}; 64 65 // CLI variables for -On options. 66 llvm::cl::opt<bool> optO0{"O0", 67 llvm::cl::desc("Run opt passes and codegen at O0"), 68 llvm::cl::cat(optFlags)}; 69 llvm::cl::opt<bool> optO1{"O1", 70 llvm::cl::desc("Run opt passes and codegen at O1"), 71 llvm::cl::cat(optFlags)}; 72 llvm::cl::opt<bool> optO2{"O2", 73 llvm::cl::desc("Run opt passes and codegen at O2"), 74 llvm::cl::cat(optFlags)}; 75 llvm::cl::opt<bool> optO3{"O3", 76 llvm::cl::desc("Run opt passes and codegen at O3"), 77 llvm::cl::cat(optFlags)}; 78 79 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 80 llvm::cl::list<std::string> clSharedLibs{ 81 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 82 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 83 llvm::cl::cat(clOptionsCategory)}; 84 85 /// CLI variables for debugging. 86 llvm::cl::opt<bool> dumpObjectFile{ 87 "dump-object-file", 88 llvm::cl::desc("Dump JITted-compiled object to file specified with " 89 "-object-filename (<input file>.o by default).")}; 90 91 llvm::cl::opt<std::string> objectFilename{ 92 "object-filename", 93 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 94 }; 95 } // end anonymous namespace 96 97 static OwningModuleRef parseMLIRInput(StringRef inputFilename, 98 MLIRContext *context) { 99 // Set up the input file. 100 std::string errorMessage; 101 auto file = openInputFile(inputFilename, &errorMessage); 102 if (!file) { 103 llvm::errs() << errorMessage << "\n"; 104 return nullptr; 105 } 106 107 llvm::SourceMgr sourceMgr; 108 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 109 return OwningModuleRef(parseSourceFile(sourceMgr, context)); 110 } 111 112 static inline Error make_string_error(const Twine &message) { 113 return llvm::make_error<llvm::StringError>(message.str(), 114 llvm::inconvertibleErrorCode()); 115 } 116 117 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 118 Optional<unsigned> optLevel; 119 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 120 options.optO0, options.optO1, options.optO2, options.optO3}; 121 122 // Determine if there is an optimization flag present. 123 for (unsigned j = 0; j < 4; ++j) { 124 auto &flag = optFlags[j].get(); 125 if (flag) { 126 optLevel = j; 127 break; 128 } 129 } 130 return optLevel; 131 } 132 133 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 134 static Error 135 compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint, 136 std::function<llvm::Error(llvm::Module *)> transformer, 137 void **args) { 138 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 139 if (auto clOptLevel = getCommandLineOptLevel(options)) 140 jitCodeGenOptLevel = 141 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 142 SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(), 143 options.clSharedLibs.end()); 144 auto expectedEngine = mlir::ExecutionEngine::create(module, transformer, 145 jitCodeGenOptLevel, libs); 146 if (!expectedEngine) 147 return expectedEngine.takeError(); 148 149 auto engine = std::move(*expectedEngine); 150 auto expectedFPtr = engine->lookup(entryPoint); 151 if (!expectedFPtr) 152 return expectedFPtr.takeError(); 153 154 if (options.dumpObjectFile) 155 engine->dumpToObjectFile(options.objectFilename.empty() 156 ? options.inputFilename + ".o" 157 : options.objectFilename); 158 159 void (*fptr)(void **) = *expectedFPtr; 160 (*fptr)(args); 161 162 return Error::success(); 163 } 164 165 static Error compileAndExecuteVoidFunction( 166 Options &options, ModuleOp module, StringRef entryPoint, 167 std::function<llvm::Error(llvm::Module *)> transformer) { 168 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 169 if (!mainFunction || mainFunction.getBlocks().empty()) 170 return make_string_error("entry point not found"); 171 void *empty = nullptr; 172 return compileAndExecute(options, module, entryPoint, transformer, &empty); 173 } 174 175 static Error compileAndExecuteSingleFloatReturnFunction( 176 Options &options, ModuleOp module, StringRef entryPoint, 177 std::function<llvm::Error(llvm::Module *)> transformer) { 178 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 179 if (!mainFunction || mainFunction.isExternal()) 180 return make_string_error("entry point not found"); 181 182 if (mainFunction.getType().getFunctionNumParams() != 0) 183 return make_string_error("function inputs not supported"); 184 185 if (!mainFunction.getType().getFunctionResultType().isFloatTy()) 186 return make_string_error("only single llvm.f32 function result supported"); 187 188 float res; 189 struct { 190 void *data; 191 } data; 192 data.data = &res; 193 if (auto error = compileAndExecute(options, module, entryPoint, transformer, 194 (void **)&data)) 195 return error; 196 197 // Intentional printing of the output so we can test. 198 llvm::outs() << res << '\n'; 199 return Error::success(); 200 } 201 202 /// Entry point for all CPU runners. Expects the common argc/argv arguments for 203 /// standard C++ main functions and an mlirTransformer. 204 /// The latter is applied after parsing the input into MLIR IR and before 205 /// passing the MLIR module to the ExecutionEngine. 206 int mlir::JitRunnerMain( 207 int argc, char **argv, 208 function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { 209 // Create the options struct containing the command line options for the 210 // runner. This must come before the command line options are parsed. 211 Options options; 212 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 213 214 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 215 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 216 options.optO0, options.optO1, options.optO2, options.optO3}; 217 unsigned optCLIPosition = 0; 218 // Determine if there is an optimization flag present, and its CLI position 219 // (optCLIPosition). 220 for (unsigned j = 0; j < 4; ++j) { 221 auto &flag = optFlags[j].get(); 222 if (flag) { 223 optCLIPosition = flag.getPosition(); 224 break; 225 } 226 } 227 // Generate vector of pass information, plus the index at which we should 228 // insert any optimization passes in that vector (optPosition). 229 SmallVector<const llvm::PassInfo *, 4> passes; 230 unsigned optPosition = 0; 231 for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) { 232 passes.push_back(options.llvmPasses[i]); 233 if (optCLIPosition < options.llvmPasses.getPosition(i)) { 234 optPosition = i; 235 optCLIPosition = UINT_MAX; // To ensure we never insert again 236 } 237 } 238 239 MLIRContext context; 240 auto m = parseMLIRInput(options.inputFilename, &context); 241 if (!m) { 242 llvm::errs() << "could not parse the input IR\n"; 243 return 1; 244 } 245 246 if (mlirTransformer) 247 if (failed(mlirTransformer(m.get()))) 248 return EXIT_FAILURE; 249 250 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 251 if (!tmBuilderOrError) { 252 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 253 return EXIT_FAILURE; 254 } 255 auto tmOrError = tmBuilderOrError->createTargetMachine(); 256 if (!tmOrError) { 257 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 258 return EXIT_FAILURE; 259 } 260 261 auto transformer = mlir::makeLLVMPassesTransformer( 262 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); 263 264 // Get the function used to compile and execute the module. 265 using CompileAndExecuteFnT = 266 Error (*)(Options &, ModuleOp, StringRef, 267 std::function<llvm::Error(llvm::Module *)>); 268 auto compileAndExecuteFn = 269 llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 270 .Case("f32", compileAndExecuteSingleFloatReturnFunction) 271 .Case("void", compileAndExecuteVoidFunction) 272 .Default(nullptr); 273 274 Error error = 275 compileAndExecuteFn 276 ? compileAndExecuteFn(options, m.get(), 277 options.mainFuncName.getValue(), transformer) 278 : make_string_error("unsupported function type"); 279 280 int exitCode = EXIT_SUCCESS; 281 llvm::handleAllErrors(std::move(error), 282 [&exitCode](const llvm::ErrorInfoBase &info) { 283 llvm::errs() << "Error: "; 284 info.log(llvm::errs()); 285 llvm::errs() << '\n'; 286 exitCode = EXIT_FAILURE; 287 }); 288 289 return exitCode; 290 } 291