1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/IR/StandardTypes.h" 25 #include "mlir/Parser.h" 26 #include "mlir/Support/FileUtilities.h" 27 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/LLVMContext.h" 32 #include "llvm/IR/LegacyPassNameParser.h" 33 #include "llvm/IR/Module.h" 34 #include "llvm/Support/CommandLine.h" 35 #include "llvm/Support/FileUtilities.h" 36 #include "llvm/Support/SourceMgr.h" 37 #include "llvm/Support/StringSaver.h" 38 #include "llvm/Support/ToolOutputFile.h" 39 #include <cstdint> 40 #include <numeric> 41 42 using namespace mlir; 43 using llvm::Error; 44 45 namespace { 46 /// This options struct prevents the need for global static initializers, and 47 /// is only initialized if the JITRunner is invoked. 48 struct Options { 49 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 50 llvm::cl::desc("<input file>"), 51 llvm::cl::init("-")}; 52 llvm::cl::opt<std::string> mainFuncName{ 53 "e", llvm::cl::desc("The function to be called"), 54 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 55 llvm::cl::opt<std::string> mainFuncType{ 56 "entry-point-result", 57 llvm::cl::desc("Textual description of the function type to be called"), 58 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")}; 59 60 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 61 62 // CLI list of pass information 63 llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{ 64 llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)}; 65 66 // CLI variables for -On options. 67 llvm::cl::opt<bool> optO0{"O0", 68 llvm::cl::desc("Run opt passes and codegen at O0"), 69 llvm::cl::cat(optFlags)}; 70 llvm::cl::opt<bool> optO1{"O1", 71 llvm::cl::desc("Run opt passes and codegen at O1"), 72 llvm::cl::cat(optFlags)}; 73 llvm::cl::opt<bool> optO2{"O2", 74 llvm::cl::desc("Run opt passes and codegen at O2"), 75 llvm::cl::cat(optFlags)}; 76 llvm::cl::opt<bool> optO3{"O3", 77 llvm::cl::desc("Run opt passes and codegen at O3"), 78 llvm::cl::cat(optFlags)}; 79 80 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 81 llvm::cl::list<std::string> clSharedLibs{ 82 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 83 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 84 llvm::cl::cat(clOptionsCategory)}; 85 86 /// CLI variables for debugging. 87 llvm::cl::opt<bool> dumpObjectFile{ 88 "dump-object-file", 89 llvm::cl::desc("Dump JITted-compiled object to file specified with " 90 "-object-filename (<input file>.o by default).")}; 91 92 llvm::cl::opt<std::string> objectFilename{ 93 "object-filename", 94 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 95 }; 96 } // end anonymous namespace 97 98 static OwningModuleRef parseMLIRInput(StringRef inputFilename, 99 MLIRContext *context) { 100 // Set up the input file. 101 std::string errorMessage; 102 auto file = openInputFile(inputFilename, &errorMessage); 103 if (!file) { 104 llvm::errs() << errorMessage << "\n"; 105 return nullptr; 106 } 107 108 llvm::SourceMgr sourceMgr; 109 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 110 return OwningModuleRef(parseSourceFile(sourceMgr, context)); 111 } 112 113 static inline Error make_string_error(const Twine &message) { 114 return llvm::make_error<llvm::StringError>(message.str(), 115 llvm::inconvertibleErrorCode()); 116 } 117 118 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 119 Optional<unsigned> optLevel; 120 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 121 options.optO0, options.optO1, options.optO2, options.optO3}; 122 123 // Determine if there is an optimization flag present. 124 for (unsigned j = 0; j < 4; ++j) { 125 auto &flag = optFlags[j].get(); 126 if (flag) { 127 optLevel = j; 128 break; 129 } 130 } 131 return optLevel; 132 } 133 134 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 135 static Error 136 compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint, 137 std::function<llvm::Error(llvm::Module *)> transformer, 138 void **args) { 139 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 140 if (auto clOptLevel = getCommandLineOptLevel(options)) 141 jitCodeGenOptLevel = 142 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 143 SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(), 144 options.clSharedLibs.end()); 145 auto expectedEngine = mlir::ExecutionEngine::create(module, transformer, 146 jitCodeGenOptLevel, libs); 147 if (!expectedEngine) 148 return expectedEngine.takeError(); 149 150 auto engine = std::move(*expectedEngine); 151 auto expectedFPtr = engine->lookup(entryPoint); 152 if (!expectedFPtr) 153 return expectedFPtr.takeError(); 154 155 if (options.dumpObjectFile) 156 engine->dumpToObjectFile(options.objectFilename.empty() 157 ? options.inputFilename + ".o" 158 : options.objectFilename); 159 160 void (*fptr)(void **) = *expectedFPtr; 161 (*fptr)(args); 162 163 return Error::success(); 164 } 165 166 static Error compileAndExecuteVoidFunction( 167 Options &options, ModuleOp module, StringRef entryPoint, 168 std::function<llvm::Error(llvm::Module *)> transformer) { 169 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 170 if (!mainFunction || mainFunction.empty()) 171 return make_string_error("entry point not found"); 172 void *empty = nullptr; 173 return compileAndExecute(options, module, entryPoint, transformer, &empty); 174 } 175 176 template <typename Type> 177 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); 178 template <> 179 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { 180 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32)) 181 return make_string_error("only single llvm.i32 function result supported"); 182 return Error::success(); 183 } 184 template <> 185 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { 186 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64)) 187 return make_string_error("only single llvm.i64 function result supported"); 188 return Error::success(); 189 } 190 template <> 191 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { 192 if (!mainFunction.getType().getFunctionResultType().isFloatTy()) 193 return make_string_error("only single llvm.f32 function result supported"); 194 return Error::success(); 195 } 196 template <typename Type> 197 Error compileAndExecuteSingleReturnFunction( 198 Options &options, ModuleOp module, StringRef entryPoint, 199 std::function<llvm::Error(llvm::Module *)> transformer) { 200 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 201 if (!mainFunction || mainFunction.isExternal()) 202 return make_string_error("entry point not found"); 203 204 if (mainFunction.getType().getFunctionNumParams() != 0) 205 return make_string_error("function inputs not supported"); 206 207 if (Error error = checkCompatibleReturnType<Type>(mainFunction)) 208 return error; 209 210 Type res; 211 struct { 212 void *data; 213 } data; 214 data.data = &res; 215 if (auto error = compileAndExecute(options, module, entryPoint, transformer, 216 (void **)&data)) 217 return error; 218 219 // Intentional printing of the output so we can test. 220 llvm::outs() << res << '\n'; 221 222 return Error::success(); 223 } 224 225 /// Entry point for all CPU runners. Expects the common argc/argv 226 /// arguments for standard C++ main functions and an mlirTransformer. 227 /// The latter is applied after parsing the input into MLIR IR and 228 /// before passing the MLIR module to the ExecutionEngine. 229 int mlir::JitRunnerMain( 230 int argc, char **argv, 231 function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { 232 // Create the options struct containing the command line options for the 233 // runner. This must come before the command line options are parsed. 234 Options options; 235 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 236 237 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 238 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 239 options.optO0, options.optO1, options.optO2, options.optO3}; 240 unsigned optCLIPosition = 0; 241 // Determine if there is an optimization flag present, and its CLI position 242 // (optCLIPosition). 243 for (unsigned j = 0; j < 4; ++j) { 244 auto &flag = optFlags[j].get(); 245 if (flag) { 246 optCLIPosition = flag.getPosition(); 247 break; 248 } 249 } 250 // Generate vector of pass information, plus the index at which we should 251 // insert any optimization passes in that vector (optPosition). 252 SmallVector<const llvm::PassInfo *, 4> passes; 253 unsigned optPosition = 0; 254 for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) { 255 passes.push_back(options.llvmPasses[i]); 256 if (optCLIPosition < options.llvmPasses.getPosition(i)) { 257 optPosition = i; 258 optCLIPosition = UINT_MAX; // To ensure we never insert again 259 } 260 } 261 262 MLIRContext context; 263 auto m = parseMLIRInput(options.inputFilename, &context); 264 if (!m) { 265 llvm::errs() << "could not parse the input IR\n"; 266 return 1; 267 } 268 269 if (mlirTransformer) 270 if (failed(mlirTransformer(m.get()))) 271 return EXIT_FAILURE; 272 273 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 274 if (!tmBuilderOrError) { 275 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 276 return EXIT_FAILURE; 277 } 278 auto tmOrError = tmBuilderOrError->createTargetMachine(); 279 if (!tmOrError) { 280 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 281 return EXIT_FAILURE; 282 } 283 284 auto transformer = mlir::makeLLVMPassesTransformer( 285 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); 286 287 // Get the function used to compile and execute the module. 288 using CompileAndExecuteFnT = 289 Error (*)(Options &, ModuleOp, StringRef, 290 std::function<llvm::Error(llvm::Module *)>); 291 auto compileAndExecuteFn = 292 llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 293 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>) 294 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>) 295 .Case("f32", compileAndExecuteSingleReturnFunction<float>) 296 .Case("void", compileAndExecuteVoidFunction) 297 .Default(nullptr); 298 299 Error error = 300 compileAndExecuteFn 301 ? compileAndExecuteFn(options, m.get(), 302 options.mainFuncName.getValue(), transformer) 303 : make_string_error("unsupported function type"); 304 305 int exitCode = EXIT_SUCCESS; 306 llvm::handleAllErrors(std::move(error), 307 [&exitCode](const llvm::ErrorInfoBase &info) { 308 llvm::errs() << "Error: "; 309 info.log(llvm::errs()); 310 llvm::errs() << '\n'; 311 exitCode = EXIT_FAILURE; 312 }); 313 314 return exitCode; 315 } 316