1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Support/FileUtilities.h"
26 
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/LegacyPassNameParser.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FileUtilities.h"
34 #include "llvm/Support/SourceMgr.h"
35 #include "llvm/Support/StringSaver.h"
36 #include "llvm/Support/ToolOutputFile.h"
37 #include <cstdint>
38 #include <numeric>
39 
40 using namespace mlir;
41 using llvm::Error;
42 
namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
///
/// NOTE: members are constructed in declaration order, and some options refer
/// to an OptionCategory member through llvm::cl::cat(...). Keep every category
/// declared above the options that use it.
struct Options {
  /// Positional input file name; defaults to "-".
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  /// Name of the function to execute (`-e`), "main" by default.
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  /// Result type of the entry point (`-entry-point-result`). It selects which
  /// compile-and-execute routine is used; "f32" by default.
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  /// Category grouping the opt-like flags below (LLVM passes and -On).
  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // LLVM passes requested on the command line, parsed via PassNameParser.
  llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
      llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  /// Category grouping the library-linking options below.
  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  /// Comma-separated list of shared libraries to make available to the JIT.
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  /// When set, the JIT-compiled object is written to `objectFilename`, or to
  /// "<input file>.o" when `objectFilename` is empty.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump JITted-compiled object to file specified with "
                     "-object-filename (<input file>.o by default).")};

  /// Output path used by `dumpObjectFile`.
  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
};

/// Callbacks forwarded to the ExecutionEngine by compileAndExecute. These are
/// non-owning function_refs; the referenced callables must outlive the call.
struct CompileAndExecuteConfig {
  /// LLVM module transformer that is passed to ExecutionEngine.
  llvm::function_ref<llvm::Error(llvm::Module *)> transformer;

  /// A custom function that is passed to ExecutionEngine. It translates the
  /// MLIR module into an LLVM IR module.
  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
                                                   llvm::LLVMContext &)>
      llvmModuleBuilder;

  /// A custom function that is passed to ExecutionEngine to register symbols at
  /// runtime. May be null, in which case no extra symbols are contributed.
  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
      runtimeSymbolMap;
};

} // namespace
112 
113 static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
114                                             MLIRContext *context) {
115   // Set up the input file.
116   std::string errorMessage;
117   auto file = openInputFile(inputFilename, &errorMessage);
118   if (!file) {
119     llvm::errs() << errorMessage << "\n";
120     return nullptr;
121   }
122 
123   llvm::SourceMgr sourceMgr;
124   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
125   return parseSourceFile<ModuleOp>(sourceMgr, context);
126 }
127 
128 static inline Error makeStringError(const Twine &message) {
129   return llvm::make_error<llvm::StringError>(message.str(),
130                                              llvm::inconvertibleErrorCode());
131 }
132 
133 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
134   Optional<unsigned> optLevel;
135   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
136       options.optO0, options.optO1, options.optO2, options.optO3};
137 
138   // Determine if there is an optimization flag present.
139   for (unsigned j = 0; j < 4; ++j) {
140     auto &flag = optFlags[j].get();
141     if (flag) {
142       optLevel = j;
143       break;
144     }
145   }
146   return optLevel;
147 }
148 
149 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
150 static Error compileAndExecute(Options &options, ModuleOp module,
151                                StringRef entryPoint,
152                                CompileAndExecuteConfig config, void **args) {
153   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
154   if (auto clOptLevel = getCommandLineOptLevel(options))
155     jitCodeGenOptLevel =
156         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
157 
158   // If shared library implements custom mlir-runner library init and destroy
159   // functions, we'll use them to register the library with the execution
160   // engine. Otherwise we'll pass library directly to the execution engine.
161   SmallVector<SmallString<256>, 4> libPaths;
162 
163   // Use absolute library path so that gdb can find the symbol table.
164   transform(
165       options.clSharedLibs, std::back_inserter(libPaths),
166       [](std::string libPath) {
167         SmallString<256> absPath(libPath.begin(), libPath.end());
168         cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
169         return absPath;
170       });
171 
172   // Libraries that we'll pass to the ExecutionEngine for loading.
173   SmallVector<StringRef, 4> executionEngineLibs;
174 
175   using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
176   using MlirRunnerDestroyFn = void (*)();
177 
178   llvm::StringMap<void *> exportSymbols;
179   SmallVector<MlirRunnerDestroyFn> destroyFns;
180 
181   // Handle libraries that do support mlir-runner init/destroy callbacks.
182   for (auto &libPath : libPaths) {
183     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
184     void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
185     void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
186 
187     // Library does not support mlir runner, load it with ExecutionEngine.
188     if (!initSym || !destroySim) {
189       executionEngineLibs.push_back(libPath);
190       continue;
191     }
192 
193     auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
194     initFn(exportSymbols);
195 
196     auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
197     destroyFns.push_back(destroyFn);
198   }
199 
200   // Build a runtime symbol map from the config and exported symbols.
201   auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
202     auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
203                                              : llvm::orc::SymbolMap();
204     for (auto &exportSymbol : exportSymbols)
205       symbolMap[interner(exportSymbol.getKey())] =
206           llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
207     return symbolMap;
208   };
209 
210   mlir::ExecutionEngineOptions engineOptions;
211   engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
212   engineOptions.transformer = config.transformer;
213   engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
214   engineOptions.sharedLibPaths = executionEngineLibs;
215   engineOptions.enableObjectCache = true;
216   auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
217   if (!expectedEngine)
218     return expectedEngine.takeError();
219 
220   auto engine = std::move(*expectedEngine);
221   engine->registerSymbols(runtimeSymbolMap);
222 
223   auto expectedFPtr = engine->lookupPacked(entryPoint);
224   if (!expectedFPtr)
225     return expectedFPtr.takeError();
226 
227   if (options.dumpObjectFile)
228     engine->dumpToObjectFile(options.objectFilename.empty()
229                                  ? options.inputFilename + ".o"
230                                  : options.objectFilename);
231 
232   void (*fptr)(void **) = *expectedFPtr;
233   (*fptr)(args);
234 
235   // Run all dynamic library destroy callbacks to prepare for the shutdown.
236   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
237 
238   return Error::success();
239 }
240 
241 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
242                                            StringRef entryPoint,
243                                            CompileAndExecuteConfig config) {
244   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
245   if (!mainFunction || mainFunction.empty())
246     return makeStringError("entry point not found");
247   void *empty = nullptr;
248   return compileAndExecute(options, module, entryPoint, config, &empty);
249 }
250 
251 template <typename Type>
252 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
253 template <>
254 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
255   auto resultType = mainFunction.getFunctionType()
256                         .cast<LLVM::LLVMFunctionType>()
257                         .getReturnType()
258                         .dyn_cast<IntegerType>();
259   if (!resultType || resultType.getWidth() != 32)
260     return makeStringError("only single i32 function result supported");
261   return Error::success();
262 }
263 template <>
264 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
265   auto resultType = mainFunction.getFunctionType()
266                         .cast<LLVM::LLVMFunctionType>()
267                         .getReturnType()
268                         .dyn_cast<IntegerType>();
269   if (!resultType || resultType.getWidth() != 64)
270     return makeStringError("only single i64 function result supported");
271   return Error::success();
272 }
273 template <>
274 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
275   if (!mainFunction.getFunctionType()
276            .cast<LLVM::LLVMFunctionType>()
277            .getReturnType()
278            .isa<Float32Type>())
279     return makeStringError("only single f32 function result supported");
280   return Error::success();
281 }
282 template <typename Type>
283 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
284                                             StringRef entryPoint,
285                                             CompileAndExecuteConfig config) {
286   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
287   if (!mainFunction || mainFunction.isExternal())
288     return makeStringError("entry point not found");
289 
290   if (mainFunction.getFunctionType()
291           .cast<LLVM::LLVMFunctionType>()
292           .getNumParams() != 0)
293     return makeStringError("function inputs not supported");
294 
295   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
296     return error;
297 
298   Type res;
299   struct {
300     void *data;
301   } data;
302   data.data = &res;
303   if (auto error = compileAndExecute(options, module, entryPoint, config,
304                                      (void **)&data))
305     return error;
306 
307   // Intentional printing of the output so we can test.
308   llvm::outs() << res << '\n';
309 
310   return Error::success();
311 }
312 
313 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
314 /// standard C++ main functions.
315 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
316                         JitRunnerConfig config) {
317   // Create the options struct containing the command line options for the
318   // runner. This must come before the command line options are parsed.
319   Options options;
320   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
321 
322   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
323   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
324       options.optO0, options.optO1, options.optO2, options.optO3};
325   unsigned optCLIPosition = 0;
326   // Determine if there is an optimization flag present, and its CLI position
327   // (optCLIPosition).
328   for (unsigned j = 0; j < 4; ++j) {
329     auto &flag = optFlags[j].get();
330     if (flag) {
331       optCLIPosition = flag.getPosition();
332       break;
333     }
334   }
335   // Generate vector of pass information, plus the index at which we should
336   // insert any optimization passes in that vector (optPosition).
337   SmallVector<const llvm::PassInfo *, 4> passes;
338   unsigned optPosition = 0;
339   for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
340     passes.push_back(options.llvmPasses[i]);
341     if (optCLIPosition < options.llvmPasses.getPosition(i)) {
342       optPosition = i;
343       optCLIPosition = UINT_MAX; // To ensure we never insert again
344     }
345   }
346 
347   MLIRContext context(registry);
348 
349   auto m = parseMLIRInput(options.inputFilename, &context);
350   if (!m) {
351     llvm::errs() << "could not parse the input IR\n";
352     return 1;
353   }
354 
355   if (config.mlirTransformer)
356     if (failed(config.mlirTransformer(m.get())))
357       return EXIT_FAILURE;
358 
359   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
360   if (!tmBuilderOrError) {
361     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
362     return EXIT_FAILURE;
363   }
364   auto tmOrError = tmBuilderOrError->createTargetMachine();
365   if (!tmOrError) {
366     llvm::errs() << "Failed to create a TargetMachine for the host\n";
367     return EXIT_FAILURE;
368   }
369 
370   auto transformer = mlir::makeLLVMPassesTransformer(
371       passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
372 
373   CompileAndExecuteConfig compileAndExecuteConfig;
374   compileAndExecuteConfig.transformer = transformer;
375   compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
376   compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
377 
378   // Get the function used to compile and execute the module.
379   using CompileAndExecuteFnT =
380       Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
381   auto compileAndExecuteFn =
382       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
383           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
384           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
385           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
386           .Case("void", compileAndExecuteVoidFunction)
387           .Default(nullptr);
388 
389   Error error = compileAndExecuteFn
390                     ? compileAndExecuteFn(options, m.get(),
391                                           options.mainFuncName.getValue(),
392                                           compileAndExecuteConfig)
393                     : makeStringError("unsupported function type");
394 
395   int exitCode = EXIT_SUCCESS;
396   llvm::handleAllErrors(std::move(error),
397                         [&exitCode](const llvm::ErrorInfoBase &info) {
398                           llvm::errs() << "Error: ";
399                           info.log(llvm::errs());
400                           llvm::errs() << '\n';
401                           exitCode = EXIT_FAILURE;
402                         });
403 
404   return exitCode;
405 }
406