1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/Parser/Parser.h" 25 #include "mlir/Support/FileUtilities.h" 26 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 29 #include "llvm/IR/IRBuilder.h" 30 #include "llvm/IR/LLVMContext.h" 31 #include "llvm/IR/LegacyPassNameParser.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FileUtilities.h" 34 #include "llvm/Support/SourceMgr.h" 35 #include "llvm/Support/StringSaver.h" 36 #include "llvm/Support/ToolOutputFile.h" 37 #include <cstdint> 38 #include <numeric> 39 #include <utility> 40 41 using namespace mlir; 42 using llvm::Error; 43 44 namespace { 45 /// This options struct prevents the need for global static initializers, and 46 /// is only initialized if the JITRunner is invoked. 47 struct Options { 48 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 49 llvm::cl::desc("<input file>"), 50 llvm::cl::init("-")}; 51 llvm::cl::opt<std::string> mainFuncName{ 52 "e", llvm::cl::desc("The function to be called"), 53 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 54 llvm::cl::opt<std::string> mainFuncType{ 55 "entry-point-result", 56 llvm::cl::desc("Textual description of the function type to be called"), 57 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")}; 58 59 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 60 61 // CLI variables for -On options. 62 llvm::cl::opt<bool> optO0{"O0", 63 llvm::cl::desc("Run opt passes and codegen at O0"), 64 llvm::cl::cat(optFlags)}; 65 llvm::cl::opt<bool> optO1{"O1", 66 llvm::cl::desc("Run opt passes and codegen at O1"), 67 llvm::cl::cat(optFlags)}; 68 llvm::cl::opt<bool> optO2{"O2", 69 llvm::cl::desc("Run opt passes and codegen at O2"), 70 llvm::cl::cat(optFlags)}; 71 llvm::cl::opt<bool> optO3{"O3", 72 llvm::cl::desc("Run opt passes and codegen at O3"), 73 llvm::cl::cat(optFlags)}; 74 75 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 76 llvm::cl::list<std::string> clSharedLibs{ 77 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 78 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 79 llvm::cl::cat(clOptionsCategory)}; 80 81 /// CLI variables for debugging. 82 llvm::cl::opt<bool> dumpObjectFile{ 83 "dump-object-file", 84 llvm::cl::desc("Dump JITted-compiled object to file specified with " 85 "-object-filename (<input file>.o by default).")}; 86 87 llvm::cl::opt<std::string> objectFilename{ 88 "object-filename", 89 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 90 }; 91 92 struct CompileAndExecuteConfig { 93 /// LLVM module transformer that is passed to ExecutionEngine. 94 std::function<llvm::Error(llvm::Module *)> transformer; 95 96 /// A custom function that is passed to ExecutionEngine. It processes MLIR 97 /// module and creates LLVM IR module. 98 llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp, 99 llvm::LLVMContext &)> 100 llvmModuleBuilder; 101 102 /// A custom function that is passed to ExecutinEngine to register symbols at 103 /// runtime. 104 llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> 105 runtimeSymbolMap; 106 }; 107 108 } // namespace 109 110 static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename, 111 MLIRContext *context) { 112 // Set up the input file. 113 std::string errorMessage; 114 auto file = openInputFile(inputFilename, &errorMessage); 115 if (!file) { 116 llvm::errs() << errorMessage << "\n"; 117 return nullptr; 118 } 119 120 llvm::SourceMgr sourceMgr; 121 sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); 122 return parseSourceFile<ModuleOp>(sourceMgr, context); 123 } 124 125 static inline Error makeStringError(const Twine &message) { 126 return llvm::make_error<llvm::StringError>(message.str(), 127 llvm::inconvertibleErrorCode()); 128 } 129 130 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 131 Optional<unsigned> optLevel; 132 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 133 options.optO0, options.optO1, options.optO2, options.optO3}; 134 135 // Determine if there is an optimization flag present. 136 for (unsigned j = 0; j < 4; ++j) { 137 auto &flag = optFlags[j].get(); 138 if (flag) { 139 optLevel = j; 140 break; 141 } 142 } 143 return optLevel; 144 } 145 146 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 147 static Error compileAndExecute(Options &options, ModuleOp module, 148 StringRef entryPoint, 149 CompileAndExecuteConfig config, void **args) { 150 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 151 if (auto clOptLevel = getCommandLineOptLevel(options)) 152 jitCodeGenOptLevel = 153 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 154 155 // If shared library implements custom mlir-runner library init and destroy 156 // functions, we'll use them to register the library with the execution 157 // engine. Otherwise we'll pass library directly to the execution engine. 158 SmallVector<SmallString<256>, 4> libPaths; 159 160 // Use absolute library path so that gdb can find the symbol table. 161 transform( 162 options.clSharedLibs, std::back_inserter(libPaths), 163 [](std::string libPath) { 164 SmallString<256> absPath(libPath.begin(), libPath.end()); 165 cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); 166 return absPath; 167 }); 168 169 // Libraries that we'll pass to the ExecutionEngine for loading. 170 SmallVector<StringRef, 4> executionEngineLibs; 171 172 using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &); 173 using MlirRunnerDestroyFn = void (*)(); 174 175 llvm::StringMap<void *> exportSymbols; 176 SmallVector<MlirRunnerDestroyFn> destroyFns; 177 178 // Handle libraries that do support mlir-runner init/destroy callbacks. 179 for (auto &libPath : libPaths) { 180 auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); 181 void *initSym = lib.getAddressOfSymbol("__mlir_runner_init"); 182 void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy"); 183 184 // Library does not support mlir runner, load it with ExecutionEngine. 185 if (!initSym || !destroySim) { 186 executionEngineLibs.push_back(libPath); 187 continue; 188 } 189 190 auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym); 191 initFn(exportSymbols); 192 193 auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim); 194 destroyFns.push_back(destroyFn); 195 } 196 197 // Build a runtime symbol map from the config and exported symbols. 198 auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) { 199 auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner) 200 : llvm::orc::SymbolMap(); 201 for (auto &exportSymbol : exportSymbols) 202 symbolMap[interner(exportSymbol.getKey())] = 203 llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue()); 204 return symbolMap; 205 }; 206 207 mlir::ExecutionEngineOptions engineOptions; 208 engineOptions.llvmModuleBuilder = config.llvmModuleBuilder; 209 if (config.transformer) 210 engineOptions.transformer = config.transformer; 211 engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel; 212 engineOptions.sharedLibPaths = executionEngineLibs; 213 engineOptions.enableObjectCache = true; 214 auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions); 215 if (!expectedEngine) 216 return expectedEngine.takeError(); 217 218 auto engine = std::move(*expectedEngine); 219 engine->registerSymbols(runtimeSymbolMap); 220 221 auto expectedFPtr = engine->lookupPacked(entryPoint); 222 if (!expectedFPtr) 223 return expectedFPtr.takeError(); 224 225 if (options.dumpObjectFile) 226 engine->dumpToObjectFile(options.objectFilename.empty() 227 ? options.inputFilename + ".o" 228 : options.objectFilename); 229 230 void (*fptr)(void **) = *expectedFPtr; 231 (*fptr)(args); 232 233 // Run all dynamic library destroy callbacks to prepare for the shutdown. 234 llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); }); 235 236 return Error::success(); 237 } 238 239 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module, 240 StringRef entryPoint, 241 CompileAndExecuteConfig config) { 242 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 243 if (!mainFunction || mainFunction.empty()) 244 return makeStringError("entry point not found"); 245 void *empty = nullptr; 246 return compileAndExecute(options, module, entryPoint, std::move(config), 247 &empty); 248 } 249 250 template <typename Type> 251 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); 252 template <> 253 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { 254 auto resultType = mainFunction.getFunctionType() 255 .cast<LLVM::LLVMFunctionType>() 256 .getReturnType() 257 .dyn_cast<IntegerType>(); 258 if (!resultType || resultType.getWidth() != 32) 259 return makeStringError("only single i32 function result supported"); 260 return Error::success(); 261 } 262 template <> 263 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { 264 auto resultType = mainFunction.getFunctionType() 265 .cast<LLVM::LLVMFunctionType>() 266 .getReturnType() 267 .dyn_cast<IntegerType>(); 268 if (!resultType || resultType.getWidth() != 64) 269 return makeStringError("only single i64 function result supported"); 270 return Error::success(); 271 } 272 template <> 273 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { 274 if (!mainFunction.getFunctionType() 275 .cast<LLVM::LLVMFunctionType>() 276 .getReturnType() 277 .isa<Float32Type>()) 278 return makeStringError("only single f32 function result supported"); 279 return Error::success(); 280 } 281 template <typename Type> 282 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module, 283 StringRef entryPoint, 284 CompileAndExecuteConfig config) { 285 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 286 if (!mainFunction || mainFunction.isExternal()) 287 return makeStringError("entry point not found"); 288 289 if (mainFunction.getFunctionType() 290 .cast<LLVM::LLVMFunctionType>() 291 .getNumParams() != 0) 292 return makeStringError("function inputs not supported"); 293 294 if (Error error = checkCompatibleReturnType<Type>(mainFunction)) 295 return error; 296 297 Type res; 298 struct { 299 void *data; 300 } data; 301 data.data = &res; 302 if (auto error = compileAndExecute(options, module, entryPoint, 303 std::move(config), (void **)&data)) 304 return error; 305 306 // Intentional printing of the output so we can test. 307 llvm::outs() << res << '\n'; 308 309 return Error::success(); 310 } 311 312 /// Entry point for all CPU runners. Expects the common argc/argv arguments for 313 /// standard C++ main functions. 314 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, 315 JitRunnerConfig config) { 316 // Create the options struct containing the command line options for the 317 // runner. This must come before the command line options are parsed. 318 Options options; 319 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 320 321 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 322 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 323 options.optO0, options.optO1, options.optO2, options.optO3}; 324 325 MLIRContext context(registry); 326 327 auto m = parseMLIRInput(options.inputFilename, &context); 328 if (!m) { 329 llvm::errs() << "could not parse the input IR\n"; 330 return 1; 331 } 332 333 if (config.mlirTransformer) 334 if (failed(config.mlirTransformer(m.get()))) 335 return EXIT_FAILURE; 336 337 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 338 if (!tmBuilderOrError) { 339 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 340 return EXIT_FAILURE; 341 } 342 auto tmOrError = tmBuilderOrError->createTargetMachine(); 343 if (!tmOrError) { 344 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 345 return EXIT_FAILURE; 346 } 347 348 CompileAndExecuteConfig compileAndExecuteConfig; 349 if (optLevel) { 350 compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer( 351 *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); 352 } 353 compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; 354 compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; 355 356 // Get the function used to compile and execute the module. 357 using CompileAndExecuteFnT = 358 Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); 359 auto compileAndExecuteFn = 360 StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 361 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>) 362 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>) 363 .Case("f32", compileAndExecuteSingleReturnFunction<float>) 364 .Case("void", compileAndExecuteVoidFunction) 365 .Default(nullptr); 366 367 Error error = compileAndExecuteFn 368 ? compileAndExecuteFn(options, m.get(), 369 options.mainFuncName.getValue(), 370 compileAndExecuteConfig) 371 : makeStringError("unsupported function type"); 372 373 int exitCode = EXIT_SUCCESS; 374 llvm::handleAllErrors(std::move(error), 375 [&exitCode](const llvm::ErrorInfoBase &info) { 376 llvm::errs() << "Error: "; 377 info.log(llvm::errs()); 378 llvm::errs() << '\n'; 379 exitCode = EXIT_FAILURE; 380 }); 381 382 return exitCode; 383 } 384