1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Support/FileUtilities.h"
26 
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
29 #include "llvm/IR/IRBuilder.h"
30 #include "llvm/IR/LLVMContext.h"
31 #include "llvm/IR/LegacyPassNameParser.h"
32 #include "llvm/Support/CommandLine.h"
33 #include "llvm/Support/FileUtilities.h"
34 #include "llvm/Support/SourceMgr.h"
35 #include "llvm/Support/StringSaver.h"
36 #include "llvm/Support/ToolOutputFile.h"
37 #include <cstdint>
38 #include <numeric>
39 #include <utility>
40 
41 using namespace mlir;
42 using llvm::Error;
43 
44 namespace {
45 /// This options struct prevents the need for global static initializers, and
46 /// is only initialized if the JITRunner is invoked.
47 struct Options {
48   llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
49                                            llvm::cl::desc("<input file>"),
50                                            llvm::cl::init("-")};
51   llvm::cl::opt<std::string> mainFuncName{
52       "e", llvm::cl::desc("The function to be called"),
53       llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
54   llvm::cl::opt<std::string> mainFuncType{
55       "entry-point-result",
56       llvm::cl::desc("Textual description of the function type to be called"),
57       llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
58 
59   llvm::cl::OptionCategory optFlags{"opt-like flags"};
60 
61   // CLI variables for -On options.
62   llvm::cl::opt<bool> optO0{"O0",
63                             llvm::cl::desc("Run opt passes and codegen at O0"),
64                             llvm::cl::cat(optFlags)};
65   llvm::cl::opt<bool> optO1{"O1",
66                             llvm::cl::desc("Run opt passes and codegen at O1"),
67                             llvm::cl::cat(optFlags)};
68   llvm::cl::opt<bool> optO2{"O2",
69                             llvm::cl::desc("Run opt passes and codegen at O2"),
70                             llvm::cl::cat(optFlags)};
71   llvm::cl::opt<bool> optO3{"O3",
72                             llvm::cl::desc("Run opt passes and codegen at O3"),
73                             llvm::cl::cat(optFlags)};
74 
75   llvm::cl::OptionCategory clOptionsCategory{"linking options"};
76   llvm::cl::list<std::string> clSharedLibs{
77       "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
78       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
79       llvm::cl::cat(clOptionsCategory)};
80 
81   /// CLI variables for debugging.
82   llvm::cl::opt<bool> dumpObjectFile{
83       "dump-object-file",
84       llvm::cl::desc("Dump JITted-compiled object to file specified with "
85                      "-object-filename (<input file>.o by default).")};
86 
87   llvm::cl::opt<std::string> objectFilename{
88       "object-filename",
89       llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
90 };
91 
92 struct CompileAndExecuteConfig {
93   /// LLVM module transformer that is passed to ExecutionEngine.
94   std::function<llvm::Error(llvm::Module *)> transformer;
95 
96   /// A custom function that is passed to ExecutionEngine. It processes MLIR
97   /// module and creates LLVM IR module.
98   llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
99                                                    llvm::LLVMContext &)>
100       llvmModuleBuilder;
101 
102   /// A custom function that is passed to ExecutinEngine to register symbols at
103   /// runtime.
104   llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
105       runtimeSymbolMap;
106 };
107 
108 } // namespace
109 
110 static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
111                                             MLIRContext *context) {
112   // Set up the input file.
113   std::string errorMessage;
114   auto file = openInputFile(inputFilename, &errorMessage);
115   if (!file) {
116     llvm::errs() << errorMessage << "\n";
117     return nullptr;
118   }
119 
120   llvm::SourceMgr sourceMgr;
121   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
122   return parseSourceFile<ModuleOp>(sourceMgr, context);
123 }
124 
125 static inline Error makeStringError(const Twine &message) {
126   return llvm::make_error<llvm::StringError>(message.str(),
127                                              llvm::inconvertibleErrorCode());
128 }
129 
130 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
131   Optional<unsigned> optLevel;
132   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
133       options.optO0, options.optO1, options.optO2, options.optO3};
134 
135   // Determine if there is an optimization flag present.
136   for (unsigned j = 0; j < 4; ++j) {
137     auto &flag = optFlags[j].get();
138     if (flag) {
139       optLevel = j;
140       break;
141     }
142   }
143   return optLevel;
144 }
145 
146 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
147 static Error compileAndExecute(Options &options, ModuleOp module,
148                                StringRef entryPoint,
149                                CompileAndExecuteConfig config, void **args) {
150   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
151   if (auto clOptLevel = getCommandLineOptLevel(options))
152     jitCodeGenOptLevel =
153         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
154 
155   // If shared library implements custom mlir-runner library init and destroy
156   // functions, we'll use them to register the library with the execution
157   // engine. Otherwise we'll pass library directly to the execution engine.
158   SmallVector<SmallString<256>, 4> libPaths;
159 
160   // Use absolute library path so that gdb can find the symbol table.
161   transform(
162       options.clSharedLibs, std::back_inserter(libPaths),
163       [](std::string libPath) {
164         SmallString<256> absPath(libPath.begin(), libPath.end());
165         cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
166         return absPath;
167       });
168 
169   // Libraries that we'll pass to the ExecutionEngine for loading.
170   SmallVector<StringRef, 4> executionEngineLibs;
171 
172   using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
173   using MlirRunnerDestroyFn = void (*)();
174 
175   llvm::StringMap<void *> exportSymbols;
176   SmallVector<MlirRunnerDestroyFn> destroyFns;
177 
178   // Handle libraries that do support mlir-runner init/destroy callbacks.
179   for (auto &libPath : libPaths) {
180     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
181     void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
182     void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
183 
184     // Library does not support mlir runner, load it with ExecutionEngine.
185     if (!initSym || !destroySim) {
186       executionEngineLibs.push_back(libPath);
187       continue;
188     }
189 
190     auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
191     initFn(exportSymbols);
192 
193     auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
194     destroyFns.push_back(destroyFn);
195   }
196 
197   // Build a runtime symbol map from the config and exported symbols.
198   auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
199     auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
200                                              : llvm::orc::SymbolMap();
201     for (auto &exportSymbol : exportSymbols)
202       symbolMap[interner(exportSymbol.getKey())] =
203           llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
204     return symbolMap;
205   };
206 
207   mlir::ExecutionEngineOptions engineOptions;
208   engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
209   if (config.transformer)
210     engineOptions.transformer = config.transformer;
211   engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
212   engineOptions.sharedLibPaths = executionEngineLibs;
213   engineOptions.enableObjectCache = true;
214   auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
215   if (!expectedEngine)
216     return expectedEngine.takeError();
217 
218   auto engine = std::move(*expectedEngine);
219   engine->registerSymbols(runtimeSymbolMap);
220 
221   auto expectedFPtr = engine->lookupPacked(entryPoint);
222   if (!expectedFPtr)
223     return expectedFPtr.takeError();
224 
225   if (options.dumpObjectFile)
226     engine->dumpToObjectFile(options.objectFilename.empty()
227                                  ? options.inputFilename + ".o"
228                                  : options.objectFilename);
229 
230   void (*fptr)(void **) = *expectedFPtr;
231   (*fptr)(args);
232 
233   // Run all dynamic library destroy callbacks to prepare for the shutdown.
234   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
235 
236   return Error::success();
237 }
238 
239 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
240                                            StringRef entryPoint,
241                                            CompileAndExecuteConfig config) {
242   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
243   if (!mainFunction || mainFunction.empty())
244     return makeStringError("entry point not found");
245   void *empty = nullptr;
246   return compileAndExecute(options, module, entryPoint, std::move(config),
247                            &empty);
248 }
249 
250 template <typename Type>
251 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
252 template <>
253 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
254   auto resultType = mainFunction.getFunctionType()
255                         .cast<LLVM::LLVMFunctionType>()
256                         .getReturnType()
257                         .dyn_cast<IntegerType>();
258   if (!resultType || resultType.getWidth() != 32)
259     return makeStringError("only single i32 function result supported");
260   return Error::success();
261 }
262 template <>
263 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
264   auto resultType = mainFunction.getFunctionType()
265                         .cast<LLVM::LLVMFunctionType>()
266                         .getReturnType()
267                         .dyn_cast<IntegerType>();
268   if (!resultType || resultType.getWidth() != 64)
269     return makeStringError("only single i64 function result supported");
270   return Error::success();
271 }
272 template <>
273 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
274   if (!mainFunction.getFunctionType()
275            .cast<LLVM::LLVMFunctionType>()
276            .getReturnType()
277            .isa<Float32Type>())
278     return makeStringError("only single f32 function result supported");
279   return Error::success();
280 }
281 template <typename Type>
282 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
283                                             StringRef entryPoint,
284                                             CompileAndExecuteConfig config) {
285   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
286   if (!mainFunction || mainFunction.isExternal())
287     return makeStringError("entry point not found");
288 
289   if (mainFunction.getFunctionType()
290           .cast<LLVM::LLVMFunctionType>()
291           .getNumParams() != 0)
292     return makeStringError("function inputs not supported");
293 
294   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
295     return error;
296 
297   Type res;
298   struct {
299     void *data;
300   } data;
301   data.data = &res;
302   if (auto error = compileAndExecute(options, module, entryPoint,
303                                      std::move(config), (void **)&data))
304     return error;
305 
306   // Intentional printing of the output so we can test.
307   llvm::outs() << res << '\n';
308 
309   return Error::success();
310 }
311 
312 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
313 /// standard C++ main functions.
314 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
315                         JitRunnerConfig config) {
316   // Create the options struct containing the command line options for the
317   // runner. This must come before the command line options are parsed.
318   Options options;
319   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
320 
321   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
322   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
323       options.optO0, options.optO1, options.optO2, options.optO3};
324 
325   MLIRContext context(registry);
326 
327   auto m = parseMLIRInput(options.inputFilename, &context);
328   if (!m) {
329     llvm::errs() << "could not parse the input IR\n";
330     return 1;
331   }
332 
333   if (config.mlirTransformer)
334     if (failed(config.mlirTransformer(m.get())))
335       return EXIT_FAILURE;
336 
337   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
338   if (!tmBuilderOrError) {
339     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
340     return EXIT_FAILURE;
341   }
342   auto tmOrError = tmBuilderOrError->createTargetMachine();
343   if (!tmOrError) {
344     llvm::errs() << "Failed to create a TargetMachine for the host\n";
345     return EXIT_FAILURE;
346   }
347 
348   CompileAndExecuteConfig compileAndExecuteConfig;
349   if (optLevel) {
350     compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
351         *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
352   }
353   compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
354   compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
355 
356   // Get the function used to compile and execute the module.
357   using CompileAndExecuteFnT =
358       Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
359   auto compileAndExecuteFn =
360       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
361           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
362           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
363           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
364           .Case("void", compileAndExecuteVoidFunction)
365           .Default(nullptr);
366 
367   Error error = compileAndExecuteFn
368                     ? compileAndExecuteFn(options, m.get(),
369                                           options.mainFuncName.getValue(),
370                                           compileAndExecuteConfig)
371                     : makeStringError("unsupported function type");
372 
373   int exitCode = EXIT_SUCCESS;
374   llvm::handleAllErrors(std::move(error),
375                         [&exitCode](const llvm::ErrorInfoBase &info) {
376                           llvm::errs() << "Error: ";
377                           info.log(llvm::errs());
378                           llvm::errs() << '\n';
379                           exitCode = EXIT_FAILURE;
380                         });
381 
382   return exitCode;
383 }
384