//===- jit-runner.cpp - MLIR CPU Execution Driver Library ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/JitRunner.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include <climits>
#include <cstdint>
#include <cstdlib>
#include <numeric>

using namespace mlir;
using llvm::Error;

namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
struct Options {
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI list of pass information.
  llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
      llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump JIT-compiled object to file specified with "
                     "-object-filename (<input file>.o by default).")};

  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Name of the output file for -dump-object-file "
                     "(<input file>.o by default)")};
};
} // end anonymous namespace

static OwningModuleRef parseMLIRInput(StringRef inputFilename,
                                      MLIRContext *context) {
  // Set up the input file.
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
  return OwningModuleRef(parseSourceFile(sourceMgr, context));
}

static inline Error make_string_error(const Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}

static Optional<unsigned> getCommandLineOptLevel(Options &options) {
  Optional<unsigned> optLevel;
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  // Determine if there is an optimization flag present.
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optLevel = j;
      break;
    }
  }
  return optLevel;
}

// JIT-compile the given module and run "entryPoint" with "args" as arguments.
static Error
compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
                  std::function<llvm::Error(llvm::Module *)> transformer,
                  void **args) {
  Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
  if (auto clOptLevel = getCommandLineOptLevel(options))
    jitCodeGenOptLevel =
        static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
  SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
                                 options.clSharedLibs.end());
  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer,
                                                      jitCodeGenOptLevel, libs);
  if (!expectedEngine)
    return expectedEngine.takeError();

  auto engine = std::move(*expectedEngine);
  auto expectedFPtr = engine->lookup(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();

  if (options.dumpObjectFile)
    engine->dumpToObjectFile(options.objectFilename.empty()
                                 ? options.inputFilename + ".o"
                                 : options.objectFilename);

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(args);

  return Error::success();
}

static Error compileAndExecuteVoidFunction(
    Options &options, ModuleOp module, StringRef entryPoint,
    std::function<llvm::Error(llvm::Module *)> transformer) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.empty())
    return make_string_error("entry point not found");
  void *empty = nullptr;
  return compileAndExecute(options, module, entryPoint, transformer, &empty);
}

// Checks that the entry point function returns the single scalar type
// requested via -entry-point-result.
template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
    return make_string_error("only single llvm.i32 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
    return make_string_error("only single llvm.i64 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
    return make_string_error("only single llvm.f32 function result supported");
  return Error::success();
}
template <typename Type>
Error compileAndExecuteSingleReturnFunction(
    Options &options, ModuleOp module, StringRef entryPoint,
    std::function<llvm::Error(llvm::Module *)> transformer) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.isExternal())
    return make_string_error("entry point not found");

  if (mainFunction.getType().getFunctionNumParams() != 0)
    return make_string_error("function inputs not supported");

  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
    return error;

  Type res;
  struct {
    void *data;
  } data;
  data.data = &res;
  if (auto error = compileAndExecute(options, module, entryPoint, transformer,
                                     (void **)&data))
    return error;

  // Intentional printing of the output so we can test.
  llvm::outs() << res << '\n';

  return Error::success();
}

/// Entry point for all CPU runners. Expects the common argc/argv arguments for
/// standard C++ main functions and an mlirTransformer. The latter is applied
/// after parsing the input into MLIR IR and before passing the MLIR module to
/// the ExecutionEngine. (A sketch of how a tool wires this up appears at the
/// end of this file.)
int mlir::JitRunnerMain(
    int argc, char **argv,
    function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
  // Create the options struct containing the command line options for the
  // runner. This must come before the command line options are parsed.
  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

  Optional<unsigned> optLevel = getCommandLineOptLevel(options);
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};
  unsigned optCLIPosition = 0;
  // Determine if there is an optimization flag present, and its CLI position
  // (optCLIPosition).
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optCLIPosition = flag.getPosition();
      break;
    }
  }
  // Generate vector of pass information, plus the index at which we should
  // insert any optimization passes in that vector (optPosition).
  SmallVector<const llvm::PassInfo *, 4> passes;
  unsigned optPosition = 0;
  for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
    passes.push_back(options.llvmPasses[i]);
    if (optCLIPosition < options.llvmPasses.getPosition(i)) {
      optPosition = i;
      optCLIPosition = UINT_MAX; // To ensure we never insert again.
    }
  }

  MLIRContext context;
  registerAllDialects(context.getDialectRegistry());

  auto m = parseMLIRInput(options.inputFilename, &context);
  if (!m) {
    llvm::errs() << "could not parse the input IR\n";
    return EXIT_FAILURE;
  }

  if (mlirTransformer)
    if (failed(mlirTransformer(m.get())))
      return EXIT_FAILURE;

  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
    return EXIT_FAILURE;
  }
  auto tmOrError = tmBuilderOrError->createTargetMachine();
  if (!tmOrError) {
    llvm::errs() << "Failed to create a TargetMachine for the host\n";
    return EXIT_FAILURE;
  }

  auto transformer = mlir::makeLLVMPassesTransformer(
      passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);

  // Get the function used to compile and execute the module.
  using CompileAndExecuteFnT =
      Error (*)(Options &, ModuleOp, StringRef,
                std::function<llvm::Error(llvm::Module *)>);
  auto compileAndExecuteFn =
      StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
          .Case("void", compileAndExecuteVoidFunction)
          .Default(nullptr);

  Error error =
      compileAndExecuteFn
          ? compileAndExecuteFn(options, m.get(),
                                options.mainFuncName.getValue(), transformer)
          : make_string_error("unsupported function type");

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });

  return exitCode;
}
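
//===----------------------------------------------------------------------===//
// Example usage (illustrative sketch only, kept in a comment so this file
// stays a library).
//
// A tool that reuses this driver typically keeps its own main() in a separate
// source file and forwards argc/argv to JitRunnerMain, in the style of
// mlir-cpu-runner. The exact set of initialization calls varies between tools
// and LLVM revisions; the sketch below assumes the native target, its assembly
// printer, and the LLVM pass registry must be set up before the
// ExecutionEngine is created.
//
//   #include "mlir/ExecutionEngine/JitRunner.h"
//   #include "mlir/ExecutionEngine/OptUtils.h"
//
//   #include "llvm/Support/InitLLVM.h"
//   #include "llvm/Support/TargetSelect.h"
//
//   int main(int argc, char **argv) {
//     llvm::InitLLVM y(argc, argv);
//     llvm::InitializeNativeTarget();
//     llvm::InitializeNativeTargetAsmPrinter();
//     mlir::initializeLLVMPasses();
//
//     // Passing a null transformer runs the parsed module as-is; a tool may
//     // instead pass a callback that lowers its dialects to the LLVM dialect
//     // before JIT compilation.
//     return mlir::JitRunnerMain(argc, argv, nullptr);
//   }
//===----------------------------------------------------------------------===//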