1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/InitAllDialects.h" 25 #include "mlir/Parser.h" 26 #include "mlir/Support/FileUtilities.h" 27 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 30 #include "llvm/IR/IRBuilder.h" 31 #include "llvm/IR/LLVMContext.h" 32 #include "llvm/IR/LegacyPassNameParser.h" 33 #include "llvm/Support/CommandLine.h" 34 #include "llvm/Support/FileUtilities.h" 35 #include "llvm/Support/SourceMgr.h" 36 #include "llvm/Support/StringSaver.h" 37 #include "llvm/Support/ToolOutputFile.h" 38 #include <cstdint> 39 #include <numeric> 40 41 using namespace mlir; 42 using llvm::Error; 43 44 namespace { 45 /// This options struct prevents the need for global static initializers, and 46 /// is only initialized if the JITRunner is invoked. 47 struct Options { 48 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 49 llvm::cl::desc("<input file>"), 50 llvm::cl::init("-")}; 51 llvm::cl::opt<std::string> mainFuncName{ 52 "e", llvm::cl::desc("The function to be called"), 53 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 54 llvm::cl::opt<std::string> mainFuncType{ 55 "entry-point-result", 56 llvm::cl::desc("Textual description of the function type to be called"), 57 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")}; 58 59 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 60 61 // CLI list of pass information 62 llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{ 63 llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)}; 64 65 // CLI variables for -On options. 66 llvm::cl::opt<bool> optO0{"O0", 67 llvm::cl::desc("Run opt passes and codegen at O0"), 68 llvm::cl::cat(optFlags)}; 69 llvm::cl::opt<bool> optO1{"O1", 70 llvm::cl::desc("Run opt passes and codegen at O1"), 71 llvm::cl::cat(optFlags)}; 72 llvm::cl::opt<bool> optO2{"O2", 73 llvm::cl::desc("Run opt passes and codegen at O2"), 74 llvm::cl::cat(optFlags)}; 75 llvm::cl::opt<bool> optO3{"O3", 76 llvm::cl::desc("Run opt passes and codegen at O3"), 77 llvm::cl::cat(optFlags)}; 78 79 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 80 llvm::cl::list<std::string> clSharedLibs{ 81 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 82 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 83 llvm::cl::cat(clOptionsCategory)}; 84 85 /// CLI variables for debugging. 86 llvm::cl::opt<bool> dumpObjectFile{ 87 "dump-object-file", 88 llvm::cl::desc("Dump JITted-compiled object to file specified with " 89 "-object-filename (<input file>.o by default).")}; 90 91 llvm::cl::opt<std::string> objectFilename{ 92 "object-filename", 93 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 94 }; 95 96 struct CompileAndExecuteConfig { 97 /// LLVM module transformer that is passed to ExecutionEngine. 98 llvm::function_ref<llvm::Error(llvm::Module *)> transformer; 99 100 /// A custom function that is passed to ExecutionEngine. It processes MLIR 101 /// module and creates LLVM IR module. 102 llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp, 103 llvm::LLVMContext &)> 104 llvmModuleBuilder; 105 106 /// A custom function that is passed to ExecutinEngine to register symbols at 107 /// runtime. 108 llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> 109 runtimeSymbolMap; 110 }; 111 112 } // end anonymous namespace 113 114 static OwningModuleRef parseMLIRInput(StringRef inputFilename, 115 MLIRContext *context) { 116 // Set up the input file. 117 std::string errorMessage; 118 auto file = openInputFile(inputFilename, &errorMessage); 119 if (!file) { 120 llvm::errs() << errorMessage << "\n"; 121 return nullptr; 122 } 123 124 llvm::SourceMgr sourceMgr; 125 sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc()); 126 return OwningModuleRef(parseSourceFile(sourceMgr, context)); 127 } 128 129 static inline Error make_string_error(const Twine &message) { 130 return llvm::make_error<llvm::StringError>(message.str(), 131 llvm::inconvertibleErrorCode()); 132 } 133 134 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 135 Optional<unsigned> optLevel; 136 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 137 options.optO0, options.optO1, options.optO2, options.optO3}; 138 139 // Determine if there is an optimization flag present. 140 for (unsigned j = 0; j < 4; ++j) { 141 auto &flag = optFlags[j].get(); 142 if (flag) { 143 optLevel = j; 144 break; 145 } 146 } 147 return optLevel; 148 } 149 150 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 151 static Error compileAndExecute(Options &options, ModuleOp module, 152 StringRef entryPoint, 153 CompileAndExecuteConfig config, void **args) { 154 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 155 if (auto clOptLevel = getCommandLineOptLevel(options)) 156 jitCodeGenOptLevel = 157 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 158 SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(), 159 options.clSharedLibs.end()); 160 auto expectedEngine = mlir::ExecutionEngine::create( 161 module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel, 162 libs); 163 if (!expectedEngine) 164 return expectedEngine.takeError(); 165 166 auto engine = std::move(*expectedEngine); 167 if (config.runtimeSymbolMap) 168 engine->registerSymbols(config.runtimeSymbolMap); 169 170 auto expectedFPtr = engine->lookup(entryPoint); 171 if (!expectedFPtr) 172 return expectedFPtr.takeError(); 173 174 if (options.dumpObjectFile) 175 engine->dumpToObjectFile(options.objectFilename.empty() 176 ? options.inputFilename + ".o" 177 : options.objectFilename); 178 179 void (*fptr)(void **) = *expectedFPtr; 180 (*fptr)(args); 181 182 return Error::success(); 183 } 184 185 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module, 186 StringRef entryPoint, 187 CompileAndExecuteConfig config) { 188 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 189 if (!mainFunction || mainFunction.empty()) 190 return make_string_error("entry point not found"); 191 void *empty = nullptr; 192 return compileAndExecute(options, module, entryPoint, config, &empty); 193 } 194 195 template <typename Type> 196 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); 197 template <> 198 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { 199 auto resultType = mainFunction.getType() 200 .cast<LLVM::LLVMFunctionType>() 201 .getReturnType() 202 .dyn_cast<LLVM::LLVMIntegerType>(); 203 if (!resultType || resultType.getBitWidth() != 32) 204 return make_string_error("only single llvm.i32 function result supported"); 205 return Error::success(); 206 } 207 template <> 208 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { 209 auto resultType = mainFunction.getType() 210 .cast<LLVM::LLVMFunctionType>() 211 .getReturnType() 212 .dyn_cast<LLVM::LLVMIntegerType>(); 213 if (!resultType || resultType.getBitWidth() != 64) 214 return make_string_error("only single llvm.i64 function result supported"); 215 return Error::success(); 216 } 217 template <> 218 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { 219 if (!mainFunction.getType() 220 .cast<LLVM::LLVMFunctionType>() 221 .getReturnType() 222 .isa<LLVM::LLVMFloatType>()) 223 return make_string_error("only single llvm.f32 function result supported"); 224 return Error::success(); 225 } 226 template <typename Type> 227 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module, 228 StringRef entryPoint, 229 CompileAndExecuteConfig config) { 230 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 231 if (!mainFunction || mainFunction.isExternal()) 232 return make_string_error("entry point not found"); 233 234 if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0) 235 return make_string_error("function inputs not supported"); 236 237 if (Error error = checkCompatibleReturnType<Type>(mainFunction)) 238 return error; 239 240 Type res; 241 struct { 242 void *data; 243 } data; 244 data.data = &res; 245 if (auto error = compileAndExecute(options, module, entryPoint, config, 246 (void **)&data)) 247 return error; 248 249 // Intentional printing of the output so we can test. 250 llvm::outs() << res << '\n'; 251 252 return Error::success(); 253 } 254 255 /// Entry point for all CPU runners. Expects the common argc/argv arguments for 256 /// standard C++ main functions. 257 int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) { 258 // Create the options struct containing the command line options for the 259 // runner. This must come before the command line options are parsed. 260 Options options; 261 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 262 263 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 264 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 265 options.optO0, options.optO1, options.optO2, options.optO3}; 266 unsigned optCLIPosition = 0; 267 // Determine if there is an optimization flag present, and its CLI position 268 // (optCLIPosition). 269 for (unsigned j = 0; j < 4; ++j) { 270 auto &flag = optFlags[j].get(); 271 if (flag) { 272 optCLIPosition = flag.getPosition(); 273 break; 274 } 275 } 276 // Generate vector of pass information, plus the index at which we should 277 // insert any optimization passes in that vector (optPosition). 278 SmallVector<const llvm::PassInfo *, 4> passes; 279 unsigned optPosition = 0; 280 for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) { 281 passes.push_back(options.llvmPasses[i]); 282 if (optCLIPosition < options.llvmPasses.getPosition(i)) { 283 optPosition = i; 284 optCLIPosition = UINT_MAX; // To ensure we never insert again 285 } 286 } 287 288 MLIRContext context; 289 registerAllDialects(context.getDialectRegistry()); 290 291 auto m = parseMLIRInput(options.inputFilename, &context); 292 if (!m) { 293 llvm::errs() << "could not parse the input IR\n"; 294 return 1; 295 } 296 297 if (config.mlirTransformer) 298 if (failed(config.mlirTransformer(m.get()))) 299 return EXIT_FAILURE; 300 301 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 302 if (!tmBuilderOrError) { 303 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 304 return EXIT_FAILURE; 305 } 306 auto tmOrError = tmBuilderOrError->createTargetMachine(); 307 if (!tmOrError) { 308 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 309 return EXIT_FAILURE; 310 } 311 312 auto transformer = mlir::makeLLVMPassesTransformer( 313 passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition); 314 315 CompileAndExecuteConfig compileAndExecuteConfig; 316 compileAndExecuteConfig.transformer = transformer; 317 compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; 318 compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; 319 320 // Get the function used to compile and execute the module. 321 using CompileAndExecuteFnT = 322 Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); 323 auto compileAndExecuteFn = 324 StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 325 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>) 326 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>) 327 .Case("f32", compileAndExecuteSingleReturnFunction<float>) 328 .Case("void", compileAndExecuteVoidFunction) 329 .Default(nullptr); 330 331 Error error = compileAndExecuteFn 332 ? compileAndExecuteFn(options, m.get(), 333 options.mainFuncName.getValue(), 334 compileAndExecuteConfig) 335 : make_string_error("unsupported function type"); 336 337 int exitCode = EXIT_SUCCESS; 338 llvm::handleAllErrors(std::move(error), 339 [&exitCode](const llvm::ErrorInfoBase &info) { 340 llvm::errs() << "Error: "; 341 info.log(llvm::errs()); 342 llvm::errs() << '\n'; 343 exitCode = EXIT_FAILURE; 344 }); 345 346 return exitCode; 347 } 348