//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
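//
// A tool built on top of this library (for example, mlir-cpu-runner) is
// typically invoked roughly as follows; the binary name and the shared
// library are illustrative, only the flags are defined in this file:
//
//   mlir-cpu-runner input.mlir -e main -entry-point-result=void \
//       -shared-libs=libmlir_runner_utils.so
//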
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/JitRunner.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Support/FileUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/ExecutionEngine/Orc/LLJIT.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include <cstdint>
#include <numeric>
#include <utility>

using namespace mlir;
using llvm::Error;

namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
struct Options {
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump JITted-compiled object to file specified with "
                     "-object-filename (<input file>.o by default).")};

  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};

  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
                                      llvm::cl::desc("Report host JIT support"),
                                      llvm::cl::Hidden};
};

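/// Configuration for compileAndExecute(): the hooks that are forwarded to the
/// ExecutionEngine.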
struct CompileAndExecuteConfig {
  /// LLVM module transformer that is passed to ExecutionEngine.
  std::function<llvm::Error(llvm::Module *)> transformer;

  /// A custom function that is passed to ExecutionEngine. It processes the
  /// MLIR module and creates the LLVM IR module.
  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
                                                   llvm::LLVMContext &)>
      llvmModuleBuilder;

  /// A custom function that is passed to ExecutionEngine to register symbols
  /// at runtime.
  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
      runtimeSymbolMap;
};

} // namespace

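/// Opens the given input file and parses it into an MLIR module, reporting
/// errors to stderr.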
static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
                                            MLIRContext *context) {
  // Set up the input file.
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
  return parseSourceFile<ModuleOp>(sourceMgr, context);
}

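/// Wraps the given message in an llvm::StringError with an inconvertible
/// error code.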
static inline Error makeStringError(const Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}

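/// Returns the optimization level requested with -O0/-O1/-O2/-O3 on the
/// command line, if any such flag was set.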
static Optional<unsigned> getCommandLineOptLevel(Options &options) {
  Optional<unsigned> optLevel;
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  // Determine if there is an optimization flag present.
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optLevel = j;
      break;
    }
  }
  return optLevel;
}

// JIT-compile the given module and run "entryPoint" with "args" as arguments.
static Error compileAndExecute(Options &options, ModuleOp module,
                               StringRef entryPoint,
                               CompileAndExecuteConfig config, void **args) {
  Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
  if (auto clOptLevel = getCommandLineOptLevel(options))
    jitCodeGenOptLevel = static_cast<llvm::CodeGenOpt::Level>(*clOptLevel);

  // If a shared library implements the custom mlir-runner init and destroy
  // functions, we'll use them to register the library with the execution
  // engine. Otherwise we'll pass the library directly to the execution engine.
  SmallVector<SmallString<256>, 4> libPaths;

  // Use absolute library path so that gdb can find the symbol table.
  transform(
      options.clSharedLibs, std::back_inserter(libPaths),
      [](std::string libPath) {
        SmallString<256> absPath(libPath.begin(), libPath.end());
        cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
        return absPath;
      });

  // Libraries that we'll pass to the ExecutionEngine for loading.
  SmallVector<StringRef, 4> executionEngineLibs;

  using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
  using MlirRunnerDestroyFn = void (*)();
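  // A shared library that opts into this protocol exports functions with the
  // signatures above. A minimal sketch (the symbol name "my_print" and the
  // function myPrint are hypothetical; only the entry point names and
  // signatures are fixed):
  //
  //   extern "C" void __mlir_runner_init(llvm::StringMap<void *> &symbols) {
  //     symbols["my_print"] = reinterpret_cast<void *>(&myPrint);
  //   }
  //   extern "C" void __mlir_runner_destroy() { /* release resources */ }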

  llvm::StringMap<void *> exportSymbols;
  SmallVector<MlirRunnerDestroyFn> destroyFns;

  // Handle libraries that do support mlir-runner init/destroy callbacks.
  for (auto &libPath : libPaths) {
    auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
    void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
    void *destroySym = lib.getAddressOfSymbol("__mlir_runner_destroy");

    // The library does not support the mlir-runner protocol; load it with the
    // ExecutionEngine instead.
    if (!initSym || !destroySym) {
      executionEngineLibs.push_back(libPath);
      continue;
    }

    auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
    initFn(exportSymbols);

    auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySym);
    destroyFns.push_back(destroyFn);
  }

  // Build a runtime symbol map from the config and exported symbols.
  auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
    auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
                                             : llvm::orc::SymbolMap();
    for (auto &exportSymbol : exportSymbols)
      symbolMap[interner(exportSymbol.getKey())] =
          llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
    return symbolMap;
  };

  mlir::ExecutionEngineOptions engineOptions;
  engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
  if (config.transformer)
    engineOptions.transformer = config.transformer;
  engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
  engineOptions.sharedLibPaths = executionEngineLibs;
  engineOptions.enableObjectCache = true;
  auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
  if (!expectedEngine)
    return expectedEngine.takeError();

  auto engine = std::move(*expectedEngine);
  engine->registerSymbols(runtimeSymbolMap);

  auto expectedFPtr = engine->lookupPacked(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();

  if (options.dumpObjectFile)
    engine->dumpToObjectFile(options.objectFilename.empty()
                                 ? options.inputFilename + ".o"
                                 : options.objectFilename);

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(args);

  // Run all dynamic library destroy callbacks to prepare for the shutdown.
  llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });

  return Error::success();
}

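/// JIT-compiles the module and runs the given entry point, expecting no
/// return value.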
static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
                                           StringRef entryPoint,
                                           CompileAndExecuteConfig config) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.empty())
    return makeStringError("entry point not found");
  void *empty = nullptr;
  return compileAndExecute(options, module, entryPoint, std::move(config),
                           &empty);
}

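/// Checks that the entry point returns a single result of the MLIR type that
/// corresponds to the C++ template parameter (i32, i64, or f32).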
template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
  auto resultType = mainFunction.getFunctionType()
                        .cast<LLVM::LLVMFunctionType>()
                        .getReturnType()
                        .dyn_cast<IntegerType>();
  if (!resultType || resultType.getWidth() != 32)
    return makeStringError("only single i32 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
  auto resultType = mainFunction.getFunctionType()
                        .cast<LLVM::LLVMFunctionType>()
                        .getReturnType()
                        .dyn_cast<IntegerType>();
  if (!resultType || resultType.getWidth() != 64)
    return makeStringError("only single i64 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getFunctionType()
           .cast<LLVM::LLVMFunctionType>()
           .getReturnType()
           .isa<Float32Type>())
    return makeStringError("only single f32 function result supported");
  return Error::success();
}
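/// JIT-compiles the module, runs the zero-argument entry point, and prints
/// its single result of the given type.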
template <typename Type>
Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
                                            StringRef entryPoint,
                                            CompileAndExecuteConfig config) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.isExternal())
    return makeStringError("entry point not found");

  if (mainFunction.getFunctionType()
          .cast<LLVM::LLVMFunctionType>()
          .getNumParams() != 0)
    return makeStringError("function inputs not supported");

  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
    return error;

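  // The entry point is invoked through the packed interface, which takes an
  // array of pointers to the individual arguments and results. With no
  // arguments, the array holds a single pointer to the result storage.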
  Type res;
  struct {
    void *data;
  } data;
  data.data = &res;
  if (auto error = compileAndExecute(options, module, entryPoint,
                                     std::move(config), (void **)&data))
    return error;

  // Intentional printing of the output so we can test.
  llvm::outs() << res << '\n';

  return Error::success();
}

/// Entry point for all CPU runners. Expects the common argc/argv arguments for
/// standard C++ main functions.
int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
                        JitRunnerConfig config) {
  // Create the options struct containing the command line options for the
  // runner. This must come before the command line options are parsed.
  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

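  // With -host-supports-jit, only probe whether an LLJIT instance can be
  // created for the host, print "true" or "false", and exit.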
  if (options.hostSupportsJit) {
    auto J = llvm::orc::LLJITBuilder().create();
    if (J)
      llvm::outs() << "true\n";
    else {
      llvm::consumeError(J.takeError());
      llvm::outs() << "false\n";
    }
    return 0;
  }

  Optional<unsigned> optLevel = getCommandLineOptLevel(options);
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  MLIRContext context(registry);

  auto m = parseMLIRInput(options.inputFilename, &context);
  if (!m) {
    llvm::errs() << "could not parse the input IR\n";
    return 1;
  }

  if (config.mlirTransformer)
    if (failed(config.mlirTransformer(m.get())))
      return EXIT_FAILURE;

  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
    return EXIT_FAILURE;
  }
  auto tmOrError = tmBuilderOrError->createTargetMachine();
  if (!tmOrError) {
    llvm::errs() << "Failed to create a TargetMachine for the host\n";
    return EXIT_FAILURE;
  }

  CompileAndExecuteConfig compileAndExecuteConfig;
  if (optLevel) {
    compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
        *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
  }
  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;

  // Get the function used to compile and execute the module.
  using CompileAndExecuteFnT =
      Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
  auto compileAndExecuteFn =
      StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
          .Case("void", compileAndExecuteVoidFunction)
          .Default(nullptr);

  Error error = compileAndExecuteFn
                    ? compileAndExecuteFn(options, m.get(),
                                          options.mainFuncName.getValue(),
                                          compileAndExecuteConfig)
                    : makeStringError("unsupported function type");

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });

  return exitCode;
}