1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This is a library that provides a shared implementation for command line 10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM 11 // IR before JIT-compiling and executing the latter. 12 // 13 // The translation can be customized by providing an MLIR to MLIR 14 // transformation. 15 //===----------------------------------------------------------------------===// 16 17 #include "mlir/ExecutionEngine/JitRunner.h" 18 19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h" 20 #include "mlir/ExecutionEngine/ExecutionEngine.h" 21 #include "mlir/ExecutionEngine/OptUtils.h" 22 #include "mlir/IR/BuiltinTypes.h" 23 #include "mlir/IR/MLIRContext.h" 24 #include "mlir/Parser/Parser.h" 25 #include "mlir/Support/FileUtilities.h" 26 27 #include "llvm/ADT/STLExtras.h" 28 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" 29 #include "llvm/IR/IRBuilder.h" 30 #include "llvm/IR/LLVMContext.h" 31 #include "llvm/IR/LegacyPassNameParser.h" 32 #include "llvm/Support/CommandLine.h" 33 #include "llvm/Support/FileUtilities.h" 34 #include "llvm/Support/SourceMgr.h" 35 #include "llvm/Support/StringSaver.h" 36 #include "llvm/Support/ToolOutputFile.h" 37 #include <cstdint> 38 #include <numeric> 39 40 using namespace mlir; 41 using llvm::Error; 42 43 namespace { 44 /// This options struct prevents the need for global static initializers, and 45 /// is only initialized if the JITRunner is invoked. 46 struct Options { 47 llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional, 48 llvm::cl::desc("<input file>"), 49 llvm::cl::init("-")}; 50 llvm::cl::opt<std::string> mainFuncName{ 51 "e", llvm::cl::desc("The function to be called"), 52 llvm::cl::value_desc("<function name>"), llvm::cl::init("main")}; 53 llvm::cl::opt<std::string> mainFuncType{ 54 "entry-point-result", 55 llvm::cl::desc("Textual description of the function type to be called"), 56 llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")}; 57 58 llvm::cl::OptionCategory optFlags{"opt-like flags"}; 59 60 // CLI variables for -On options. 61 llvm::cl::opt<bool> optO0{"O0", 62 llvm::cl::desc("Run opt passes and codegen at O0"), 63 llvm::cl::cat(optFlags)}; 64 llvm::cl::opt<bool> optO1{"O1", 65 llvm::cl::desc("Run opt passes and codegen at O1"), 66 llvm::cl::cat(optFlags)}; 67 llvm::cl::opt<bool> optO2{"O2", 68 llvm::cl::desc("Run opt passes and codegen at O2"), 69 llvm::cl::cat(optFlags)}; 70 llvm::cl::opt<bool> optO3{"O3", 71 llvm::cl::desc("Run opt passes and codegen at O3"), 72 llvm::cl::cat(optFlags)}; 73 74 llvm::cl::OptionCategory clOptionsCategory{"linking options"}; 75 llvm::cl::list<std::string> clSharedLibs{ 76 "shared-libs", llvm::cl::desc("Libraries to link dynamically"), 77 llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated, 78 llvm::cl::cat(clOptionsCategory)}; 79 80 /// CLI variables for debugging. 81 llvm::cl::opt<bool> dumpObjectFile{ 82 "dump-object-file", 83 llvm::cl::desc("Dump JITted-compiled object to file specified with " 84 "-object-filename (<input file>.o by default).")}; 85 86 llvm::cl::opt<std::string> objectFilename{ 87 "object-filename", 88 llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")}; 89 }; 90 91 struct CompileAndExecuteConfig { 92 /// LLVM module transformer that is passed to ExecutionEngine. 93 std::function<llvm::Error(llvm::Module *)> transformer; 94 95 /// A custom function that is passed to ExecutionEngine. It processes MLIR 96 /// module and creates LLVM IR module. 97 llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp, 98 llvm::LLVMContext &)> 99 llvmModuleBuilder; 100 101 /// A custom function that is passed to ExecutinEngine to register symbols at 102 /// runtime. 103 llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)> 104 runtimeSymbolMap; 105 }; 106 107 } // namespace 108 109 static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename, 110 MLIRContext *context) { 111 // Set up the input file. 112 std::string errorMessage; 113 auto file = openInputFile(inputFilename, &errorMessage); 114 if (!file) { 115 llvm::errs() << errorMessage << "\n"; 116 return nullptr; 117 } 118 119 llvm::SourceMgr sourceMgr; 120 sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); 121 return parseSourceFile<ModuleOp>(sourceMgr, context); 122 } 123 124 static inline Error makeStringError(const Twine &message) { 125 return llvm::make_error<llvm::StringError>(message.str(), 126 llvm::inconvertibleErrorCode()); 127 } 128 129 static Optional<unsigned> getCommandLineOptLevel(Options &options) { 130 Optional<unsigned> optLevel; 131 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 132 options.optO0, options.optO1, options.optO2, options.optO3}; 133 134 // Determine if there is an optimization flag present. 135 for (unsigned j = 0; j < 4; ++j) { 136 auto &flag = optFlags[j].get(); 137 if (flag) { 138 optLevel = j; 139 break; 140 } 141 } 142 return optLevel; 143 } 144 145 // JIT-compile the given module and run "entryPoint" with "args" as arguments. 146 static Error compileAndExecute(Options &options, ModuleOp module, 147 StringRef entryPoint, 148 CompileAndExecuteConfig config, void **args) { 149 Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel; 150 if (auto clOptLevel = getCommandLineOptLevel(options)) 151 jitCodeGenOptLevel = 152 static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue()); 153 154 // If shared library implements custom mlir-runner library init and destroy 155 // functions, we'll use them to register the library with the execution 156 // engine. Otherwise we'll pass library directly to the execution engine. 157 SmallVector<SmallString<256>, 4> libPaths; 158 159 // Use absolute library path so that gdb can find the symbol table. 160 transform( 161 options.clSharedLibs, std::back_inserter(libPaths), 162 [](std::string libPath) { 163 SmallString<256> absPath(libPath.begin(), libPath.end()); 164 cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath))); 165 return absPath; 166 }); 167 168 // Libraries that we'll pass to the ExecutionEngine for loading. 169 SmallVector<StringRef, 4> executionEngineLibs; 170 171 using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &); 172 using MlirRunnerDestroyFn = void (*)(); 173 174 llvm::StringMap<void *> exportSymbols; 175 SmallVector<MlirRunnerDestroyFn> destroyFns; 176 177 // Handle libraries that do support mlir-runner init/destroy callbacks. 178 for (auto &libPath : libPaths) { 179 auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str()); 180 void *initSym = lib.getAddressOfSymbol("__mlir_runner_init"); 181 void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy"); 182 183 // Library does not support mlir runner, load it with ExecutionEngine. 184 if (!initSym || !destroySim) { 185 executionEngineLibs.push_back(libPath); 186 continue; 187 } 188 189 auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym); 190 initFn(exportSymbols); 191 192 auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim); 193 destroyFns.push_back(destroyFn); 194 } 195 196 // Build a runtime symbol map from the config and exported symbols. 197 auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) { 198 auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner) 199 : llvm::orc::SymbolMap(); 200 for (auto &exportSymbol : exportSymbols) 201 symbolMap[interner(exportSymbol.getKey())] = 202 llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue()); 203 return symbolMap; 204 }; 205 206 mlir::ExecutionEngineOptions engineOptions; 207 engineOptions.llvmModuleBuilder = config.llvmModuleBuilder; 208 engineOptions.transformer = config.transformer; 209 engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel; 210 engineOptions.sharedLibPaths = executionEngineLibs; 211 engineOptions.enableObjectCache = true; 212 auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions); 213 if (!expectedEngine) 214 return expectedEngine.takeError(); 215 216 auto engine = std::move(*expectedEngine); 217 engine->registerSymbols(runtimeSymbolMap); 218 219 auto expectedFPtr = engine->lookupPacked(entryPoint); 220 if (!expectedFPtr) 221 return expectedFPtr.takeError(); 222 223 if (options.dumpObjectFile) 224 engine->dumpToObjectFile(options.objectFilename.empty() 225 ? options.inputFilename + ".o" 226 : options.objectFilename); 227 228 void (*fptr)(void **) = *expectedFPtr; 229 (*fptr)(args); 230 231 // Run all dynamic library destroy callbacks to prepare for the shutdown. 232 llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); }); 233 234 return Error::success(); 235 } 236 237 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module, 238 StringRef entryPoint, 239 CompileAndExecuteConfig config) { 240 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 241 if (!mainFunction || mainFunction.empty()) 242 return makeStringError("entry point not found"); 243 void *empty = nullptr; 244 return compileAndExecute(options, module, entryPoint, config, &empty); 245 } 246 247 template <typename Type> 248 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction); 249 template <> 250 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) { 251 auto resultType = mainFunction.getFunctionType() 252 .cast<LLVM::LLVMFunctionType>() 253 .getReturnType() 254 .dyn_cast<IntegerType>(); 255 if (!resultType || resultType.getWidth() != 32) 256 return makeStringError("only single i32 function result supported"); 257 return Error::success(); 258 } 259 template <> 260 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) { 261 auto resultType = mainFunction.getFunctionType() 262 .cast<LLVM::LLVMFunctionType>() 263 .getReturnType() 264 .dyn_cast<IntegerType>(); 265 if (!resultType || resultType.getWidth() != 64) 266 return makeStringError("only single i64 function result supported"); 267 return Error::success(); 268 } 269 template <> 270 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) { 271 if (!mainFunction.getFunctionType() 272 .cast<LLVM::LLVMFunctionType>() 273 .getReturnType() 274 .isa<Float32Type>()) 275 return makeStringError("only single f32 function result supported"); 276 return Error::success(); 277 } 278 template <typename Type> 279 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module, 280 StringRef entryPoint, 281 CompileAndExecuteConfig config) { 282 auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint); 283 if (!mainFunction || mainFunction.isExternal()) 284 return makeStringError("entry point not found"); 285 286 if (mainFunction.getFunctionType() 287 .cast<LLVM::LLVMFunctionType>() 288 .getNumParams() != 0) 289 return makeStringError("function inputs not supported"); 290 291 if (Error error = checkCompatibleReturnType<Type>(mainFunction)) 292 return error; 293 294 Type res; 295 struct { 296 void *data; 297 } data; 298 data.data = &res; 299 if (auto error = compileAndExecute(options, module, entryPoint, config, 300 (void **)&data)) 301 return error; 302 303 // Intentional printing of the output so we can test. 304 llvm::outs() << res << '\n'; 305 306 return Error::success(); 307 } 308 309 /// Entry point for all CPU runners. Expects the common argc/argv arguments for 310 /// standard C++ main functions. 311 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, 312 JitRunnerConfig config) { 313 // Create the options struct containing the command line options for the 314 // runner. This must come before the command line options are parsed. 315 Options options; 316 llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n"); 317 318 Optional<unsigned> optLevel = getCommandLineOptLevel(options); 319 SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{ 320 options.optO0, options.optO1, options.optO2, options.optO3}; 321 322 MLIRContext context(registry); 323 324 auto m = parseMLIRInput(options.inputFilename, &context); 325 if (!m) { 326 llvm::errs() << "could not parse the input IR\n"; 327 return 1; 328 } 329 330 if (config.mlirTransformer) 331 if (failed(config.mlirTransformer(m.get()))) 332 return EXIT_FAILURE; 333 334 auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost(); 335 if (!tmBuilderOrError) { 336 llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n"; 337 return EXIT_FAILURE; 338 } 339 auto tmOrError = tmBuilderOrError->createTargetMachine(); 340 if (!tmOrError) { 341 llvm::errs() << "Failed to create a TargetMachine for the host\n"; 342 return EXIT_FAILURE; 343 } 344 345 CompileAndExecuteConfig compileAndExecuteConfig; 346 if (optLevel) { 347 compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer( 348 *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get()); 349 } 350 compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder; 351 compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap; 352 353 // Get the function used to compile and execute the module. 354 using CompileAndExecuteFnT = 355 Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); 356 auto compileAndExecuteFn = 357 StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue()) 358 .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>) 359 .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>) 360 .Case("f32", compileAndExecuteSingleReturnFunction<float>) 361 .Case("void", compileAndExecuteVoidFunction) 362 .Default(nullptr); 363 364 Error error = compileAndExecuteFn 365 ? compileAndExecuteFn(options, m.get(), 366 options.mainFuncName.getValue(), 367 compileAndExecuteConfig) 368 : makeStringError("unsupported function type"); 369 370 int exitCode = EXIT_SUCCESS; 371 llvm::handleAllErrors(std::move(error), 372 [&exitCode](const llvm::ErrorInfoBase &info) { 373 llvm::errs() << "Error: "; 374 info.log(llvm::errs()); 375 llvm::errs() << '\n'; 376 exitCode = EXIT_FAILURE; 377 }); 378 379 return exitCode; 380 } 381