//===- jit-runner.cpp - MLIR CPU Execution Driver Library ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
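//
// Example: a minimal runner tool built on top of this library (an illustrative
// sketch, not a prescribed setup; the exact dialect and translation
// registration calls depend on what the tool needs):
//
//   int main(int argc, char **argv) {
//     llvm::InitLLVM y(argc, argv);
//     llvm::InitializeNativeTarget();
//     llvm::InitializeNativeTargetAsmPrinter();
//
//     mlir::DialectRegistry registry;
//     mlir::registerAllToLLVMIRTranslations(registry);
//
//     return mlir::JitRunnerMain(argc, argv, registry, /*config=*/{});
//   }
//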
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/JitRunner.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdint>
#include <numeric>

using namespace mlir;
using llvm::Error;

namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
struct Options {
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI list of pass information.
  llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
      llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump the JIT-compiled object to the file specified with "
                     "-object-filename (<input file>.o by default).")};

  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Name of the file to dump the JIT-compiled object to "
                     "(<input file>.o by default)")};
};

struct CompileAndExecuteConfig {
  /// LLVM module transformer that is passed to ExecutionEngine.
  llvm::function_ref<llvm::Error(llvm::Module *)> transformer;

  /// A custom function that is passed to ExecutionEngine. It processes the
  /// MLIR module and creates an LLVM IR module.
  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
                                                   llvm::LLVMContext &)>
      llvmModuleBuilder;

  /// A custom function that is passed to ExecutionEngine to register symbols
  /// at runtime.
  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
      runtimeSymbolMap;
};

} // namespace

static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
                                            MLIRContext *context) {
  // Set up the input file.
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
  return parseSourceFile<ModuleOp>(sourceMgr, context);
}

static inline Error makeStringError(const Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}

static Optional<unsigned> getCommandLineOptLevel(Options &options) {
  Optional<unsigned> optLevel;
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  // Determine if there is an optimization flag present.
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optLevel = j;
      break;
    }
  }
  return optLevel;
}

// JIT-compile the given module and run "entryPoint" with "args" as arguments.
static Error compileAndExecute(Options &options, ModuleOp module,
                               StringRef entryPoint,
                               CompileAndExecuteConfig config, void **args) {
  Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
  if (auto clOptLevel = getCommandLineOptLevel(options))
    jitCodeGenOptLevel =
        static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());

  // If a shared library implements custom mlir-runner init and destroy
  // functions, we'll use them to register the library's symbols with the
  // execution engine. Otherwise we'll pass the library directly to the
  // execution engine.
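  //
  // A library that opts into this protocol exports the two hooks below
  // (an illustrative sketch; the "print_helper"/printHelper symbol is made up
  // for the example):
  //
  //   extern "C" void
  //   __mlir_runner_init(llvm::StringMap<void *> &exportSymbols) {
  //     exportSymbols["print_helper"] = reinterpret_cast<void *>(&printHelper);
  //   }
  //   extern "C" void __mlir_runner_destroy() { /* release resources */ }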
  SmallVector<SmallString<256>, 4> libPaths;

  // Use absolute library path so that gdb can find the symbol table.
  transform(
      options.clSharedLibs, std::back_inserter(libPaths),
      [](std::string libPath) {
        SmallString<256> absPath(libPath.begin(), libPath.end());
        cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
        return absPath;
      });

  // Libraries that we'll pass to the ExecutionEngine for loading.
  SmallVector<StringRef, 4> executionEngineLibs;

  using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
  using MlirRunnerDestroyFn = void (*)();

  llvm::StringMap<void *> exportSymbols;
  SmallVector<MlirRunnerDestroyFn> destroyFns;

  // Handle libraries that do support mlir-runner init/destroy callbacks.
  for (auto &libPath : libPaths) {
    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
    void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
    void *destroySym = lib.getAddressOfSymbol("__mlir_runner_destroy");

    // Library does not support the mlir-runner protocol; load it with the
    // ExecutionEngine.
    if (!initSym || !destroySym) {
      executionEngineLibs.push_back(libPath);
      continue;
    }

    auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
    initFn(exportSymbols);

    auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySym);
    destroyFns.push_back(destroyFn);
  }

  // Build a runtime symbol map from the config and exported symbols.
  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
    auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
                                             : llvm::orc::SymbolMap();
    for (auto &exportSymbol : exportSymbols)
      symbolMap[interner(exportSymbol.getKey())] =
          llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
    return symbolMap;
  };

  mlir::ExecutionEngineOptions engineOptions;
  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
  engineOptions.transformer = config.transformer;
  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
  engineOptions.sharedLibPaths = executionEngineLibs;
  engineOptions.enableObjectCache = true;
  auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
  if (!expectedEngine)
    return expectedEngine.takeError();

  auto engine = std::move(*expectedEngine);
  engine->registerSymbols(runtimeSymbolMap);

  auto expectedFPtr = engine->lookupPacked(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();

  if (options.dumpObjectFile)
    engine->dumpToObjectFile(options.objectFilename.empty()
                                 ? options.inputFilename + ".o"
                                 : options.objectFilename);

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(args);

  // Run all dynamic library destroy callbacks to prepare for shutdown.
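// With the default options (-e main -entry-point-result=f32), the runner looks
// up a single-result entry point, runs it, and prints the returned value. An
// illustrative input accepted by the f32 path is:
//
//   llvm.func @main() -> f32 {
//     %0 = llvm.mlir.constant(4.2 : f32) : f32
//     llvm.return %0 : f32
//   }
//
// The helpers below validate the declared result type and run such functions.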
  llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });

  return Error::success();
}

static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
                                           StringRef entryPoint,
                                           CompileAndExecuteConfig config) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.empty())
    return makeStringError("entry point not found");
  void *empty = nullptr;
  return compileAndExecute(options, module, entryPoint, config, &empty);
}

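template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
  auto resultType = mainFunction.getFunctionType()
                        .cast<LLVM::LLVMFunctionType>()
                        .getReturnType()
                        .dyn_cast<IntegerType>();
  if (!resultType || resultType.getWidth() != 32)
    return makeStringError("only single i32 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
  auto resultType = mainFunction.getFunctionType()
                        .cast<LLVM::LLVMFunctionType>()
                        .getReturnType()
                        .dyn_cast<IntegerType>();
  if (!resultType || resultType.getWidth() != 64)
    return makeStringError("only single i64 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getFunctionType()
           .cast<LLVM::LLVMFunctionType>()
           .getReturnType()
           .isa<Float32Type>())
    return makeStringError("only single f32 function result supported");
  return Error::success();
}
template <typename Type>
Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
                                            StringRef entryPoint,
                                            CompileAndExecuteConfig config) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.isExternal())
    return makeStringError("entry point not found");

  if (mainFunction.getFunctionType()
          .cast<LLVM::LLVMFunctionType>()
          .getNumParams() != 0)
    return makeStringError("function inputs not supported");

  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
    return error;

  Type res;
  struct {
    void *data;
  } data;
  data.data = &res;
  if (auto error = compileAndExecute(options, module, entryPoint, config,
                                     (void **)&data))
    return error;

  // Intentional printing of the output so we can test.
  llvm::outs() << res << '\n';

  return Error::success();
}

// Example invocation of a tool built on this library (the tool and library
// names are illustrative; the flags are the ones defined above):
//
//   mlir-cpu-runner input.mlir -e main -entry-point-result=i32 \
//       -shared-libs=libmlir_runner_utils.so -O2
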
/// Entry point for all CPU runners. Expects the common argc/argv arguments for
/// standard C++ main functions.
int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
                        JitRunnerConfig config) {
  // Create the options struct containing the command line options for the
  // runner. This must come before the command line options are parsed.
  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

  Optional<unsigned> optLevel = getCommandLineOptLevel(options);
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};
  unsigned optCLIPosition = 0;
  // Determine if there is an optimization flag present, and its CLI position
  // (optCLIPosition).
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optCLIPosition = flag.getPosition();
      break;
    }
  }
  // Generate vector of pass information, plus the index at which we should
  // insert any optimization passes in that vector (optPosition).
  SmallVector<const llvm::PassInfo *, 4> passes;
  unsigned optPosition = 0;
  for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
    passes.push_back(options.llvmPasses[i]);
    if (optCLIPosition < options.llvmPasses.getPosition(i)) {
      optPosition = i;
      optCLIPosition = UINT_MAX; // To ensure we never insert again
    }
  }

  MLIRContext context(registry);

  auto m = parseMLIRInput(options.inputFilename, &context);
  if (!m) {
    llvm::errs() << "could not parse the input IR\n";
    return 1;
  }

  if (config.mlirTransformer)
    if (failed(config.mlirTransformer(m.get())))
      return EXIT_FAILURE;

  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
    return EXIT_FAILURE;
  }
  auto tmOrError = tmBuilderOrError->createTargetMachine();
  if (!tmOrError) {
    llvm::errs() << "Failed to create a TargetMachine for the host\n";
    return EXIT_FAILURE;
  }

  auto transformer = mlir::makeLLVMPassesTransformer(
      passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);

  CompileAndExecuteConfig compileAndExecuteConfig;
  compileAndExecuteConfig.transformer = transformer;
  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;

  // Get the function used to compile and execute the module.
  using CompileAndExecuteFnT =
      Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
  auto compileAndExecuteFn =
      StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
          .Case("void", compileAndExecuteVoidFunction)
          .Default(nullptr);

  Error error = compileAndExecuteFn
                    ? compileAndExecuteFn(options, m.get(),
                                          options.mainFuncName.getValue(),
                                          compileAndExecuteConfig)
                    : makeStringError("unsupported function type");

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });

  return exitCode;
}