1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/Module.h" 24 #include "mlir/IR/StandardTypes.h" 25 #include "mlir/Parser.h" 26 #include "mlir/Support/FileUtilities.h" 27 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/LLVMContext.h" 32 #include "llvm/IR/LegacyPassNameParser.h" 33 #include "llvm/IR/Module.h" 34 #include "llvm/Support/CommandLine.h" 35 #include "llvm/Support/FileUtilities.h" 36 #include "llvm/Support/SourceMgr.h" 37 #include "llvm/Support/StringSaver.h" 38 #include "llvm/Support/ToolOutputFile.h" 39 #include <cstdint> 40 #include <numeric> 41 42 using namespace mlir; 43 using llvm::Error; 44 45 namespace { 46 /// This options struct prevents the need for global static initializers, and 47 /// is only initialized if the JITRunner is invoked. 48 struct Options { 49 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 50 llvm::cl::desc("<input file>"), 51 llvm::cl::init("-")}; 52 llvm::cl::opt<std::string> mainFuncName{ 53 "e", llvm::cl::desc("The function to be called"), 54 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 55 llvm::cl::opt<std::string> mainFuncType{ 56 "entry-point-result", 57 llvm::cl::desc("Textual description of the function type to be called"), 58 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")}; 59 60 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 61 62 // CLI list of pass information 63 llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{ 64 llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)}; 65 66 // CLI variables for -On options. 67 llvm::cl::opt<bool> optO0{"O0", 68 llvm::cl::desc("Run opt passes and codegen at O0"), 69 llvm::cl::cat(optFlags)}; 70 llvm::cl::opt<bool> optO1{"O1", 71 llvm::cl::desc("Run opt passes and codegen at O1"), 72 llvm::cl::cat(optFlags)}; 73 llvm::cl::opt<bool> optO2{"O2", 74 llvm::cl::desc("Run opt passes and codegen at O2"), 75 llvm::cl::cat(optFlags)}; 76 llvm::cl::opt<bool> optO3{"O3", 77 llvm::cl::desc("Run opt passes and codegen at O3"), 78 llvm::cl::cat(optFlags)}; 79 80 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 81 llvm::cl::list<std::string> clSharedLibs{ 82 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 83 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 84 llvm::cl::cat(clOptionsCategory)}; 85 86 /// CLI variables for debugging. 87 llvm::cl::opt<bool> dumpObjectFile{ 88 "dump-object-file", 89 llvm::cl::desc("Dump JITted-compiled object to file specified with " 90 "-object-filename (<input file>.o by default).")}; 91 92 llvm::cl::opt<std::string> objectFilename{ 93 "object-filename", 94 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 95 }; 96 } // end anonymous namespace 97 98 static OwningModuleRef parseMLIRInput(StringRef inputFilename, 99 MLIRContext *context) { 100 // Set up the input file. 101 std::string errorMessage; 102 auto file = openInputFile(inputFilename, &errorMessage); 103 if (!file) { 104 llvm::errs() << errorMessage << "\n"; 105 return nullptr; 106 } 107 108 llvm::SourceMgr sourceMgr; 109 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 110 return OwningModuleRef(parseSourceFile(sourceMgr, context)); 111 } 112 113 static inline Error make_string_error(const Twine &message) { 114 return llvm::make_error<llvm::StringError>(message.str(), 115 llvm::inconvertibleErrorCode()); 116 } 117 118 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 119 Optional<unsigned> optLevel; 120 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 121 options.optO0, options.optO1, options.optO2, options.optO3}; 122 123 // Determine if there is an optimization flag present. 124 for (unsigned j = 0; j < 4; ++j) { 125 auto &flag = optFlags[j].get(); 126 if (flag) { 127 optLevel = j; 128 break; 129 } 130 } 131 return optLevel; 132 } 133 134 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 135 static Error 136 compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint, 137 std::function<llvm::Error(llvm::Module *)> transformer, 138 void **args) { 139 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 140 if (auto clOptLevel = getCommandLineOptLevel(options)) 141 jitCodeGenOptLevel = 142 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 143 SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(), 144 options.clSharedLibs.end()); 145 auto expectedEngine = mlir::ExecutionEngine::create(module, transformer, 146 jitCodeGenOptLevel, libs); 147 if (!expectedEngine) 148 return expectedEngine.takeError(); 149 150 auto engine = std::move(*expectedEngine); 151 auto expectedFPtr = engine->lookup(entryPoint); 152 if (!expectedFPtr) 153 return expectedFPtr.takeError(); 154 155 if (options.dumpObjectFile) 156 engine->dumpToObjectFile(options.objectFilename.empty() 157 ? options.inputFilename + ".o" 158 : options.objectFilename); 159 160 void (*fptr)(void **) = *expectedFPtr; 161 (*fptr)(args); 162 163 return Error::success(); 164 } 165 166 static Error compileAndExecuteVoidFunction( 167 Options &options, ModuleOp module, StringRef entryPoint, 168 std::function<llvm::Error(llvm::Module *)> transformer) { 169 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 170 if (!mainFunction || mainFunction.empty()) 171 return make_string_error("entry point not found"); 172 void *empty = nullptr; 173 return compileAndExecute(options, module, entryPoint, transformer, &empty); 174 } 175 176 template <typename Type> 177 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); 178 template <> 179 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { 180 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32)) 181 return make_string_error("only single llvm.i32 function result supported"); 182 return Error::success(); 183 } 184 template <> 185 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { 186 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64)) 187 return make_string_error("only single llvm.i64 function result supported"); 188 return Error::success(); 189 } 190 template <> 191 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { 192 if (!mainFunction.getType().getFunctionResultType().isFloatTy()) 193 return make_string_error("only single llvm.f32 function result supported"); 194 return Error::success(); 195 } 196 template <typename Type> 197 Error compileAndExecuteSingleReturnFunction( 198 Options &options, ModuleOp module, StringRef entryPoint, 199 std::function<llvm::Error(llvm::Module *)> transformer) { 200 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 201 if (!mainFunction || mainFunction.isExternal()) 202 return make_string_error("entry point not found"); 203 204 if (mainFunction.getType().getFunctionNumParams() != 0) 205 return make_string_error("function inputs not supported"); 206 207 if (Error error = checkCompatibleReturnType<Type>(mainFunction)) 208 return error; 209 210 Type res; 211 struct { 212 void *data; 213 } data; 214 data.data = &res; 215 if (auto error = compileAndExecute(options, module, entryPoint, transformer, 216 (void **)&data)) 217 return error; 218 219 // Intentional printing of the output so we can test. 220 llvm::outs() << res << '\n'; 221 222 return Error::success(); 223 } 224 225 /// Entry point for all CPU runners. Expects the common argc/argv 226 /// arguments for standard C++ main functions and an mlirTransformer. 227 /// The latter is applied after parsing the input into MLIR IR and 228 /// before passing the MLIR module to the ExecutionEngine. 229 int mlir::JitRunnerMain( 230 int argc, char **argv, 231 function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) { 232 // Create the options struct containing the command line options for the 233 // runner. This must come before the command line options are parsed. 234 Options options; 235 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 236 237 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 238 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 239 options.optO0, options.optO1, options.optO2, options.optO3}; 240 unsigned optCLIPosition = 0; 241 // Determine if there is an optimization flag present, and its CLI position 242 // (optCLIPosition). 243 for (unsigned j = 0; j < 4; ++j) { 244 auto &flag = optFlags[j].get(); 245 if (flag) { 246 optCLIPosition = flag.getPosition(); 247 break; 248 } 249 } 250 // Generate vector of pass information, plus the index at which we should 251 // insert any optimization passes in that vector (optPosition). 252 SmallVector<const llvm::PassInfo *, 4> passes; 253 unsigned optPosition = 0; 254 for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) { 255 passes.push_back(options.llvmPasses[i]); 256 if (optCLIPosition < options.llvmPasses.getPosition(i)) { 257 optPosition = i; 258 optCLIPosition = UINT_MAX; // To ensure we never insert again 259 } 260 } 261 262 MLIRContext context(/*loadAllDialects=*/false); 263 registerAllDialects(&context); 264 265 auto m = parseMLIRInput(options.inputFilename, &context); 266 if (!m) { 267 llvm::errs() << "could not parse the input IR\n"; 268 return 1; 269 } 270 271 if (mlirTransformer) 272 if (failed(mlirTransformer(m.get()))) 273 return EXIT_FAILURE; 274 275 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 276 if (!tmBuilderOrError) { 277 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 278 return EXIT_FAILURE; 279 } 280 auto tmOrError = tmBuilderOrError->createTargetMachine(); 281 if (!tmOrError) { 282 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 283 return EXIT_FAILURE; 284 } 285 286 auto transformer = mlir::makeLLVMPassesTransformer( 287 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); 288 289 // Get the function used to compile and execute the module. 290 using CompileAndExecuteFnT = 291 Error (*)(Options &, ModuleOp, StringRef, 292 std::function<llvm::Error(llvm::Module *)>); 293 auto compileAndExecuteFn = 294 llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 295 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>) 296 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>) 297 .Case("f32", compileAndExecuteSingleReturnFunction<float>) 298 .Case("void", compileAndExecuteVoidFunction) 299 .Default(nullptr); 300 301 Error error = 302 compileAndExecuteFn 303 ? compileAndExecuteFn(options, m.get(), 304 options.mainFuncName.getValue(), transformer) 305 : make_string_error("unsupported function type"); 306 307 int exitCode = EXIT_SUCCESS; 308 llvm::handleAllErrors(std::move(error), 309 [&exitCode](const llvm::ErrorInfoBase &info) { 310 llvm::errs() << "Error: "; 311 info.log(llvm::errs()); 312 llvm::errs() << '\n'; 313 exitCode = EXIT_FAILURE; 314 }); 315 316 return exitCode; 317 } 318