1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Support/FileUtilities.h"
26 
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/LegacyPassNameParser.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FileUtilities.h"
34 #include "llvm/Support/SourceMgr.h"
35 #include "llvm/Support/StringSaver.h"
36 #include "llvm/Support/ToolOutputFile.h"
37 #include <cstdint>
38 #include <numeric>
39 
40 using namespace mlir;
41 using llvm::Error;
42 
43 namespace {
44 /// This options struct prevents the need for global static initializers, and
45 /// is only initialized if the JITRunner is invoked.
46 struct Options {
47   llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
48                                            llvm::cl::desc("<input file>"),
49                                            llvm::cl::init("-")};
50   llvm::cl::opt<std::string> mainFuncName{
51       "e", llvm::cl::desc("The function to be called"),
52       llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
53   llvm::cl::opt<std::string> mainFuncType{
54       "entry-point-result",
55       llvm::cl::desc("Textual description of the function type to be called"),
56       llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
57 
58   llvm::cl::OptionCategory optFlags{"opt-like flags"};
59 
60   // CLI variables for -On options.
61   llvm::cl::opt<bool> optO0{"O0",
62                             llvm::cl::desc("Run opt passes and codegen at O0"),
63                             llvm::cl::cat(optFlags)};
64   llvm::cl::opt<bool> optO1{"O1",
65                             llvm::cl::desc("Run opt passes and codegen at O1"),
66                             llvm::cl::cat(optFlags)};
67   llvm::cl::opt<bool> optO2{"O2",
68                             llvm::cl::desc("Run opt passes and codegen at O2"),
69                             llvm::cl::cat(optFlags)};
70   llvm::cl::opt<bool> optO3{"O3",
71                             llvm::cl::desc("Run opt passes and codegen at O3"),
72                             llvm::cl::cat(optFlags)};
73 
74   llvm::cl::OptionCategory clOptionsCategory{"linking options"};
75   llvm::cl::list<std::string> clSharedLibs{
76       "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
77       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
78       llvm::cl::cat(clOptionsCategory)};
79 
80   /// CLI variables for debugging.
81   llvm::cl::opt<bool> dumpObjectFile{
82       "dump-object-file",
83       llvm::cl::desc("Dump JITted-compiled object to file specified with "
84                      "-object-filename (<input file>.o by default).")};
85 
86   llvm::cl::opt<std::string> objectFilename{
87       "object-filename",
88       llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
89 };
90 
91 struct CompileAndExecuteConfig {
92   /// LLVM module transformer that is passed to ExecutionEngine.
93   std::function<llvm::Error(llvm::Module *)> transformer;
94 
95   /// A custom function that is passed to ExecutionEngine. It processes MLIR
96   /// module and creates LLVM IR module.
97   llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
98                                                    llvm::LLVMContext &)>
99       llvmModuleBuilder;
100 
101   /// A custom function that is passed to ExecutinEngine to register symbols at
102   /// runtime.
103   llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
104       runtimeSymbolMap;
105 };
106 
107 } // namespace
108 
109 static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
110                                             MLIRContext *context) {
111   // Set up the input file.
112   std::string errorMessage;
113   auto file = openInputFile(inputFilename, &errorMessage);
114   if (!file) {
115     llvm::errs() << errorMessage << "\n";
116     return nullptr;
117   }
118 
119   llvm::SourceMgr sourceMgr;
120   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
121   return parseSourceFile<ModuleOp>(sourceMgr, context);
122 }
123 
124 static inline Error makeStringError(const Twine &message) {
125   return llvm::make_error<llvm::StringError>(message.str(),
126                                              llvm::inconvertibleErrorCode());
127 }
128 
129 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
130   Optional<unsigned> optLevel;
131   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
132       options.optO0, options.optO1, options.optO2, options.optO3};
133 
134   // Determine if there is an optimization flag present.
135   for (unsigned j = 0; j < 4; ++j) {
136     auto &flag = optFlags[j].get();
137     if (flag) {
138       optLevel = j;
139       break;
140     }
141   }
142   return optLevel;
143 }
144 
145 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
146 static Error compileAndExecute(Options &options, ModuleOp module,
147                                StringRef entryPoint,
148                                CompileAndExecuteConfig config, void **args) {
149   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
150   if (auto clOptLevel = getCommandLineOptLevel(options))
151     jitCodeGenOptLevel =
152         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
153 
154   // If shared library implements custom mlir-runner library init and destroy
155   // functions, we'll use them to register the library with the execution
156   // engine. Otherwise we'll pass library directly to the execution engine.
157   SmallVector<SmallString<256>, 4> libPaths;
158 
159   // Use absolute library path so that gdb can find the symbol table.
160   transform(
161       options.clSharedLibs, std::back_inserter(libPaths),
162       [](std::string libPath) {
163         SmallString<256> absPath(libPath.begin(), libPath.end());
164         cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
165         return absPath;
166       });
167 
168   // Libraries that we'll pass to the ExecutionEngine for loading.
169   SmallVector<StringRef, 4> executionEngineLibs;
170 
171   using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
172   using MlirRunnerDestroyFn = void (*)();
173 
174   llvm::StringMap<void *> exportSymbols;
175   SmallVector<MlirRunnerDestroyFn> destroyFns;
176 
177   // Handle libraries that do support mlir-runner init/destroy callbacks.
178   for (auto &libPath : libPaths) {
179     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
180     void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
181     void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
182 
183     // Library does not support mlir runner, load it with ExecutionEngine.
184     if (!initSym || !destroySim) {
185       executionEngineLibs.push_back(libPath);
186       continue;
187     }
188 
189     auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
190     initFn(exportSymbols);
191 
192     auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
193     destroyFns.push_back(destroyFn);
194   }
195 
196   // Build a runtime symbol map from the config and exported symbols.
197   auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
198     auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
199                                              : llvm::orc::SymbolMap();
200     for (auto &exportSymbol : exportSymbols)
201       symbolMap[interner(exportSymbol.getKey())] =
202           llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
203     return symbolMap;
204   };
205 
206   mlir::ExecutionEngineOptions engineOptions;
207   engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
208   if (config.transformer)
209     engineOptions.transformer = config.transformer;
210   engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
211   engineOptions.sharedLibPaths = executionEngineLibs;
212   engineOptions.enableObjectCache = true;
213   auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
214   if (!expectedEngine)
215     return expectedEngine.takeError();
216 
217   auto engine = std::move(*expectedEngine);
218   engine->registerSymbols(runtimeSymbolMap);
219 
220   auto expectedFPtr = engine->lookupPacked(entryPoint);
221   if (!expectedFPtr)
222     return expectedFPtr.takeError();
223 
224   if (options.dumpObjectFile)
225     engine->dumpToObjectFile(options.objectFilename.empty()
226                                  ? options.inputFilename + ".o"
227                                  : options.objectFilename);
228 
229   void (*fptr)(void **) = *expectedFPtr;
230   (*fptr)(args);
231 
232   // Run all dynamic library destroy callbacks to prepare for the shutdown.
233   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
234 
235   return Error::success();
236 }
237 
238 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
239                                            StringRef entryPoint,
240                                            CompileAndExecuteConfig config) {
241   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
242   if (!mainFunction || mainFunction.empty())
243     return makeStringError("entry point not found");
244   void *empty = nullptr;
245   return compileAndExecute(options, module, entryPoint, config, &empty);
246 }
247 
248 template <typename Type>
249 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
250 template <>
251 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
252   auto resultType = mainFunction.getFunctionType()
253                         .cast<LLVM::LLVMFunctionType>()
254                         .getReturnType()
255                         .dyn_cast<IntegerType>();
256   if (!resultType || resultType.getWidth() != 32)
257     return makeStringError("only single i32 function result supported");
258   return Error::success();
259 }
260 template <>
261 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
262   auto resultType = mainFunction.getFunctionType()
263                         .cast<LLVM::LLVMFunctionType>()
264                         .getReturnType()
265                         .dyn_cast<IntegerType>();
266   if (!resultType || resultType.getWidth() != 64)
267     return makeStringError("only single i64 function result supported");
268   return Error::success();
269 }
270 template <>
271 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
272   if (!mainFunction.getFunctionType()
273            .cast<LLVM::LLVMFunctionType>()
274            .getReturnType()
275            .isa<Float32Type>())
276     return makeStringError("only single f32 function result supported");
277   return Error::success();
278 }
279 template <typename Type>
280 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
281                                             StringRef entryPoint,
282                                             CompileAndExecuteConfig config) {
283   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
284   if (!mainFunction || mainFunction.isExternal())
285     return makeStringError("entry point not found");
286 
287   if (mainFunction.getFunctionType()
288           .cast<LLVM::LLVMFunctionType>()
289           .getNumParams() != 0)
290     return makeStringError("function inputs not supported");
291 
292   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
293     return error;
294 
295   Type res;
296   struct {
297     void *data;
298   } data;
299   data.data = &res;
300   if (auto error = compileAndExecute(options, module, entryPoint, config,
301                                      (void **)&data))
302     return error;
303 
304   // Intentional printing of the output so we can test.
305   llvm::outs() << res << '\n';
306 
307   return Error::success();
308 }
309 
310 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
311 /// standard C++ main functions.
312 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
313                         JitRunnerConfig config) {
314   // Create the options struct containing the command line options for the
315   // runner. This must come before the command line options are parsed.
316   Options options;
317   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
318 
319   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
320   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
321       options.optO0, options.optO1, options.optO2, options.optO3};
322 
323   MLIRContext context(registry);
324 
325   auto m = parseMLIRInput(options.inputFilename, &context);
326   if (!m) {
327     llvm::errs() << "could not parse the input IR\n";
328     return 1;
329   }
330 
331   if (config.mlirTransformer)
332     if (failed(config.mlirTransformer(m.get())))
333       return EXIT_FAILURE;
334 
335   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
336   if (!tmBuilderOrError) {
337     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
338     return EXIT_FAILURE;
339   }
340   auto tmOrError = tmBuilderOrError->createTargetMachine();
341   if (!tmOrError) {
342     llvm::errs() << "Failed to create a TargetMachine for the host\n";
343     return EXIT_FAILURE;
344   }
345 
346   CompileAndExecuteConfig compileAndExecuteConfig;
347   if (optLevel) {
348     compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
349         *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
350   }
351   compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
352   compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
353 
354   // Get the function used to compile and execute the module.
355   using CompileAndExecuteFnT =
356       Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
357   auto compileAndExecuteFn =
358       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
359           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
360           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
361           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
362           .Case("void", compileAndExecuteVoidFunction)
363           .Default(nullptr);
364 
365   Error error = compileAndExecuteFn
366                     ? compileAndExecuteFn(options, m.get(),
367                                           options.mainFuncName.getValue(),
368                                           compileAndExecuteConfig)
369                     : makeStringError("unsupported function type");
370 
371   int exitCode = EXIT_SUCCESS;
372   llvm::handleAllErrors(std::move(error),
373                         [&exitCode](const llvm::ErrorInfoBase &info) {
374                           llvm::errs() << "Error: ";
375                           info.log(llvm::errs());
376                           llvm::errs() << '\n';
377                           exitCode = EXIT_FAILURE;
378                         });
379 
380   return exitCode;
381 }
382