1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/MLIRContext.h" 23 #include "mlir/IR/StandardTypes.h" 24 #include "mlir/InitAllDialects.h" 25 #include "mlir/Parser.h" 26 #include "mlir/Support/FileUtilities.h" 27 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/LLVMContext.h" 32 #include "llvm/IR/LegacyPassNameParser.h" 33 #include "llvm/Support/CommandLine.h" 34 #include "llvm/Support/FileUtilities.h" 35 #include "llvm/Support/SourceMgr.h" 36 #include "llvm/Support/StringSaver.h" 37 #include "llvm/Support/ToolOutputFile.h" 38 #include <cstdint> 39 #include <numeric> 40 41 using namespace mlir; 42 using llvm::Error; 43 44 namespace { 45 /// This options struct prevents the need for global static initializers, and 46 /// is only initialized if the JITRunner is invoked. 47 struct Options { 48 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 49 llvm::cl::desc("<input file>"), 50 llvm::cl::init("-")}; 51 llvm::cl::opt<std::string> mainFuncName{ 52 "e", llvm::cl::desc("The function to be called"), 53 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 54 llvm::cl::opt<std::string> mainFuncType{ 55 "entry-point-result", 56 llvm::cl::desc("Textual description of the function type to be called"), 57 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")}; 58 59 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 60 61 // CLI list of pass information 62 llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{ 63 llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)}; 64 65 // CLI variables for -On options. 66 llvm::cl::opt<bool> optO0{"O0", 67 llvm::cl::desc("Run opt passes and codegen at O0"), 68 llvm::cl::cat(optFlags)}; 69 llvm::cl::opt<bool> optO1{"O1", 70 llvm::cl::desc("Run opt passes and codegen at O1"), 71 llvm::cl::cat(optFlags)}; 72 llvm::cl::opt<bool> optO2{"O2", 73 llvm::cl::desc("Run opt passes and codegen at O2"), 74 llvm::cl::cat(optFlags)}; 75 llvm::cl::opt<bool> optO3{"O3", 76 llvm::cl::desc("Run opt passes and codegen at O3"), 77 llvm::cl::cat(optFlags)}; 78 79 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 80 llvm::cl::list<std::string> clSharedLibs{ 81 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 82 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 83 llvm::cl::cat(clOptionsCategory)}; 84 85 /// CLI variables for debugging. 86 llvm::cl::opt<bool> dumpObjectFile{ 87 "dump-object-file", 88 llvm::cl::desc("Dump JITted-compiled object to file specified with " 89 "-object-filename (<input file>.o by default).")}; 90 91 llvm::cl::opt<std::string> objectFilename{ 92 "object-filename", 93 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 94 }; 95 } // end anonymous namespace 96 97 static OwningModuleRef parseMLIRInput(StringRef inputFilename, 98 MLIRContext *context) { 99 // Set up the input file. 100 std::string errorMessage; 101 auto file = openInputFile(inputFilename, &errorMessage); 102 if (!file) { 103 llvm::errs() << errorMessage << "\n"; 104 return nullptr; 105 } 106 107 llvm::SourceMgr sourceMgr; 108 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 109 return OwningModuleRef(parseSourceFile(sourceMgr, context)); 110 } 111 112 static inline Error make_string_error(const Twine &message) { 113 return llvm::make_error<llvm::StringError>(message.str(), 114 llvm::inconvertibleErrorCode()); 115 } 116 117 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 118 Optional<unsigned> optLevel; 119 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 120 options.optO0, options.optO1, options.optO2, options.optO3}; 121 122 // Determine if there is an optimization flag present. 123 for (unsigned j = 0; j < 4; ++j) { 124 auto &flag = optFlags[j].get(); 125 if (flag) { 126 optLevel = j; 127 break; 128 } 129 } 130 return optLevel; 131 } 132 133 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 134 static Error 135 compileAndExecute(Options &options, ModuleOp module, 136 TranslationCallback llvmModuleBuilder, StringRef entryPoint, 137 std::function<llvm::Error(llvm::Module *)> transformer, 138 void **args) { 139 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 140 if (auto clOptLevel = getCommandLineOptLevel(options)) 141 jitCodeGenOptLevel = 142 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 143 SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(), 144 options.clSharedLibs.end()); 145 auto expectedEngine = mlir::ExecutionEngine::create( 146 module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs); 147 if (!expectedEngine) 148 return expectedEngine.takeError(); 149 150 auto engine = std::move(*expectedEngine); 151 auto expectedFPtr = engine->lookup(entryPoint); 152 if (!expectedFPtr) 153 return expectedFPtr.takeError(); 154 155 if (options.dumpObjectFile) 156 engine->dumpToObjectFile(options.objectFilename.empty() 157 ? options.inputFilename + ".o" 158 : options.objectFilename); 159 160 void (*fptr)(void **) = *expectedFPtr; 161 (*fptr)(args); 162 163 return Error::success(); 164 } 165 166 static Error compileAndExecuteVoidFunction( 167 Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder, 168 StringRef entryPoint, 169 std::function<llvm::Error(llvm::Module *)> transformer) { 170 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 171 if (!mainFunction || mainFunction.empty()) 172 return make_string_error("entry point not found"); 173 void *empty = nullptr; 174 return compileAndExecute(options, module, llvmModuleBuilder, entryPoint, 175 transformer, &empty); 176 } 177 178 template <typename Type> 179 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); 180 template <> 181 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { 182 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32)) 183 return make_string_error("only single llvm.i32 function result supported"); 184 return Error::success(); 185 } 186 template <> 187 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { 188 if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64)) 189 return make_string_error("only single llvm.i64 function result supported"); 190 return Error::success(); 191 } 192 template <> 193 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { 194 if (!mainFunction.getType().getFunctionResultType().isFloatTy()) 195 return make_string_error("only single llvm.f32 function result supported"); 196 return Error::success(); 197 } 198 template <typename Type> 199 Error compileAndExecuteSingleReturnFunction( 200 Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder, 201 StringRef entryPoint, 202 std::function<llvm::Error(llvm::Module *)> transformer) { 203 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 204 if (!mainFunction || mainFunction.isExternal()) 205 return make_string_error("entry point not found"); 206 207 if (mainFunction.getType().getFunctionNumParams() != 0) 208 return make_string_error("function inputs not supported"); 209 210 if (Error error = checkCompatibleReturnType<Type>(mainFunction)) 211 return error; 212 213 Type res; 214 struct { 215 void *data; 216 } data; 217 data.data = &res; 218 if (auto error = compileAndExecute(options, module, llvmModuleBuilder, 219 entryPoint, transformer, (void **)&data)) 220 return error; 221 222 // Intentional printing of the output so we can test. 223 llvm::outs() << res << '\n'; 224 225 return Error::success(); 226 } 227 228 /// Entry point for all CPU runners. Expects the common argc/argv arguments for 229 /// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`. 230 /// `mlirTransformer` is applied after parsing the input into MLIR IR and before 231 /// passing the MLIR module to the ExecutionEngine. 232 /// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine. 233 /// It processes MLIR module and creates LLVM IR module. 234 int mlir::JitRunnerMain( 235 int argc, char **argv, 236 function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer, 237 TranslationCallback llvmModuleBuilder) { 238 // Create the options struct containing the command line options for the 239 // runner. This must come before the command line options are parsed. 240 Options options; 241 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 242 243 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 244 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 245 options.optO0, options.optO1, options.optO2, options.optO3}; 246 unsigned optCLIPosition = 0; 247 // Determine if there is an optimization flag present, and its CLI position 248 // (optCLIPosition). 249 for (unsigned j = 0; j < 4; ++j) { 250 auto &flag = optFlags[j].get(); 251 if (flag) { 252 optCLIPosition = flag.getPosition(); 253 break; 254 } 255 } 256 // Generate vector of pass information, plus the index at which we should 257 // insert any optimization passes in that vector (optPosition). 258 SmallVector<const llvm::PassInfo *, 4> passes; 259 unsigned optPosition = 0; 260 for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) { 261 passes.push_back(options.llvmPasses[i]); 262 if (optCLIPosition < options.llvmPasses.getPosition(i)) { 263 optPosition = i; 264 optCLIPosition = UINT_MAX; // To ensure we never insert again 265 } 266 } 267 268 MLIRContext context; 269 registerAllDialects(context.getDialectRegistry()); 270 271 auto m = parseMLIRInput(options.inputFilename, &context); 272 if (!m) { 273 llvm::errs() << "could not parse the input IR\n"; 274 return 1; 275 } 276 277 if (mlirTransformer) 278 if (failed(mlirTransformer(m.get()))) 279 return EXIT_FAILURE; 280 281 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 282 if (!tmBuilderOrError) { 283 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 284 return EXIT_FAILURE; 285 } 286 auto tmOrError = tmBuilderOrError->createTargetMachine(); 287 if (!tmOrError) { 288 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 289 return EXIT_FAILURE; 290 } 291 292 auto transformer = mlir::makeLLVMPassesTransformer( 293 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); 294 295 // Get the function used to compile and execute the module. 296 using CompileAndExecuteFnT = 297 Error (*)(Options &, ModuleOp, TranslationCallback, StringRef, 298 std::function<llvm::Error(llvm::Module *)>); 299 auto compileAndExecuteFn = 300 StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 301 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>) 302 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>) 303 .Case("f32", compileAndExecuteSingleReturnFunction<float>) 304 .Case("void", compileAndExecuteVoidFunction) 305 .Default(nullptr); 306 307 Error error = 308 compileAndExecuteFn 309 ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder, 310 options.mainFuncName.getValue(), transformer) 311 : make_string_error("unsupported function type"); 312 313 int exitCode = EXIT_SUCCESS; 314 llvm::handleAllErrors(std::move(error), 315 [&exitCode](const llvm::ErrorInfoBase &info) { 316 llvm::errs() << "Error: "; 317 info.log(llvm::errs()); 318 llvm::errs() << '\n'; 319 exitCode = EXIT_FAILURE; 320 }); 321 322 return exitCode; 323 } 324