1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/InitAllDialects.h"
25 #include "mlir/Parser.h"
26 #include "mlir/Support/FileUtilities.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/LegacyPassNameParser.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/FileUtilities.h"
35 #include "llvm/Support/SourceMgr.h"
36 #include "llvm/Support/StringSaver.h"
37 #include "llvm/Support/ToolOutputFile.h"
38 #include <cstdint>
39 #include <numeric>
40 
41 using namespace mlir;
42 using llvm::Error;
43 
44 namespace {
45 /// This options struct prevents the need for global static initializers, and
46 /// is only initialized if the JITRunner is invoked.
47 struct Options {
48   llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
49                                            llvm::cl::desc("<input file>"),
50                                            llvm::cl::init("-")};
51   llvm::cl::opt<std::string> mainFuncName{
52       "e", llvm::cl::desc("The function to be called"),
53       llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
54   llvm::cl::opt<std::string> mainFuncType{
55       "entry-point-result",
56       llvm::cl::desc("Textual description of the function type to be called"),
57       llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
58 
59   llvm::cl::OptionCategory optFlags{"opt-like flags"};
60 
61   // CLI list of pass information
62   llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
63       llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};
64 
65   // CLI variables for -On options.
66   llvm::cl::opt<bool> optO0{"O0",
67                             llvm::cl::desc("Run opt passes and codegen at O0"),
68                             llvm::cl::cat(optFlags)};
69   llvm::cl::opt<bool> optO1{"O1",
70                             llvm::cl::desc("Run opt passes and codegen at O1"),
71                             llvm::cl::cat(optFlags)};
72   llvm::cl::opt<bool> optO2{"O2",
73                             llvm::cl::desc("Run opt passes and codegen at O2"),
74                             llvm::cl::cat(optFlags)};
75   llvm::cl::opt<bool> optO3{"O3",
76                             llvm::cl::desc("Run opt passes and codegen at O3"),
77                             llvm::cl::cat(optFlags)};
78 
79   llvm::cl::OptionCategory clOptionsCategory{"linking options"};
80   llvm::cl::list<std::string> clSharedLibs{
81       "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
82       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
83       llvm::cl::cat(clOptionsCategory)};
84 
85   /// CLI variables for debugging.
86   llvm::cl::opt<bool> dumpObjectFile{
87       "dump-object-file",
88       llvm::cl::desc("Dump JITted-compiled object to file specified with "
89                      "-object-filename (<input file>.o by default).")};
90 
91   llvm::cl::opt<std::string> objectFilename{
92       "object-filename",
93       llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
94 };
95 
96 struct CompileAndExecuteConfig {
97   /// LLVM module transformer that is passed to ExecutionEngine.
98   llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
99 
100   /// A custom function that is passed to ExecutionEngine. It processes MLIR
101   /// module and creates LLVM IR module.
102   llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
103                                                    llvm::LLVMContext &)>
104       llvmModuleBuilder;
105 
106   /// A custom function that is passed to ExecutinEngine to register symbols at
107   /// runtime.
108   llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
109       runtimeSymbolMap;
110 };
111 
112 } // end anonymous namespace
113 
114 static OwningModuleRef parseMLIRInput(StringRef inputFilename,
115                                       MLIRContext *context) {
116   // Set up the input file.
117   std::string errorMessage;
118   auto file = openInputFile(inputFilename, &errorMessage);
119   if (!file) {
120     llvm::errs() << errorMessage << "\n";
121     return nullptr;
122   }
123 
124   llvm::SourceMgr sourceMgr;
125   sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
126   return OwningModuleRef(parseSourceFile(sourceMgr, context));
127 }
128 
129 static inline Error make_string_error(const Twine &message) {
130   return llvm::make_error<llvm::StringError>(message.str(),
131                                              llvm::inconvertibleErrorCode());
132 }
133 
134 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
135   Optional<unsigned> optLevel;
136   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
137       options.optO0, options.optO1, options.optO2, options.optO3};
138 
139   // Determine if there is an optimization flag present.
140   for (unsigned j = 0; j < 4; ++j) {
141     auto &flag = optFlags[j].get();
142     if (flag) {
143       optLevel = j;
144       break;
145     }
146   }
147   return optLevel;
148 }
149 
150 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
151 static Error compileAndExecute(Options &options, ModuleOp module,
152                                StringRef entryPoint,
153                                CompileAndExecuteConfig config, void **args) {
154   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
155   if (auto clOptLevel = getCommandLineOptLevel(options))
156     jitCodeGenOptLevel =
157         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
158 
159   // If shared library implements custom mlir-runner library init and destroy
160   // functions, we'll use them to register the library with the execution
161   // engine. Otherwise we'll pass library directly to the execution engine.
162   SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
163                                  options.clSharedLibs.end());
164 
165   // Libraries that we'll pass to the ExecutionEngine for loading.
166   SmallVector<StringRef, 4> executionEngineLibs;
167 
168   using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
169   using MlirRunnerDestroyFn = void (*)();
170 
171   llvm::StringMap<void *> exportSymbols;
172   SmallVector<MlirRunnerDestroyFn> destroyFns;
173 
174   // Handle libraries that do support mlir-runner init/destroy callbacks.
175   for (auto libPath : libs) {
176     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.data());
177     void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
178     void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
179 
180     // Library does not support mlir runner, load it with ExecutionEngine.
181     if (!initSym || !destroySim) {
182       executionEngineLibs.push_back(libPath);
183       continue;
184     }
185 
186     auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
187     initFn(exportSymbols);
188 
189     auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
190     destroyFns.push_back(destroyFn);
191   }
192 
193   // Build a runtime symbol map from the config and exported symbols.
194   auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
195     auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
196                                              : llvm::orc::SymbolMap();
197     for (auto &exportSymbol : exportSymbols)
198       symbolMap[interner(exportSymbol.getKey())] =
199           llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
200     return symbolMap;
201   };
202 
203   auto expectedEngine = mlir::ExecutionEngine::create(
204       module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
205       executionEngineLibs);
206   if (!expectedEngine)
207     return expectedEngine.takeError();
208 
209   auto engine = std::move(*expectedEngine);
210   engine->registerSymbols(runtimeSymbolMap);
211 
212   auto expectedFPtr = engine->lookup(entryPoint);
213   if (!expectedFPtr)
214     return expectedFPtr.takeError();
215 
216   if (options.dumpObjectFile)
217     engine->dumpToObjectFile(options.objectFilename.empty()
218                                  ? options.inputFilename + ".o"
219                                  : options.objectFilename);
220 
221   void (*fptr)(void **) = *expectedFPtr;
222   (*fptr)(args);
223 
224   // Run all dynamic library destroy callbacks to prepare for the shutdown.
225   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
226 
227   return Error::success();
228 }
229 
230 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
231                                            StringRef entryPoint,
232                                            CompileAndExecuteConfig config) {
233   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
234   if (!mainFunction || mainFunction.empty())
235     return make_string_error("entry point not found");
236   void *empty = nullptr;
237   return compileAndExecute(options, module, entryPoint, config, &empty);
238 }
239 
240 template <typename Type>
241 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
242 template <>
243 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
244   auto resultType = mainFunction.getType()
245                         .cast<LLVM::LLVMFunctionType>()
246                         .getReturnType()
247                         .dyn_cast<IntegerType>();
248   if (!resultType || resultType.getWidth() != 32)
249     return make_string_error("only single i32 function result supported");
250   return Error::success();
251 }
252 template <>
253 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
254   auto resultType = mainFunction.getType()
255                         .cast<LLVM::LLVMFunctionType>()
256                         .getReturnType()
257                         .dyn_cast<IntegerType>();
258   if (!resultType || resultType.getWidth() != 64)
259     return make_string_error("only single i64 function result supported");
260   return Error::success();
261 }
262 template <>
263 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
264   if (!mainFunction.getType()
265            .cast<LLVM::LLVMFunctionType>()
266            .getReturnType()
267            .isa<Float32Type>())
268     return make_string_error("only single f32 function result supported");
269   return Error::success();
270 }
271 template <typename Type>
272 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
273                                             StringRef entryPoint,
274                                             CompileAndExecuteConfig config) {
275   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
276   if (!mainFunction || mainFunction.isExternal())
277     return make_string_error("entry point not found");
278 
279   if (mainFunction.getType().cast<LLVM::LLVMFunctionType>().getNumParams() != 0)
280     return make_string_error("function inputs not supported");
281 
282   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
283     return error;
284 
285   Type res;
286   struct {
287     void *data;
288   } data;
289   data.data = &res;
290   if (auto error = compileAndExecute(options, module, entryPoint, config,
291                                      (void **)&data))
292     return error;
293 
294   // Intentional printing of the output so we can test.
295   llvm::outs() << res << '\n';
296 
297   return Error::success();
298 }
299 
300 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
301 /// standard C++ main functions.
302 int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
303   // Create the options struct containing the command line options for the
304   // runner. This must come before the command line options are parsed.
305   Options options;
306   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
307 
308   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
309   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
310       options.optO0, options.optO1, options.optO2, options.optO3};
311   unsigned optCLIPosition = 0;
312   // Determine if there is an optimization flag present, and its CLI position
313   // (optCLIPosition).
314   for (unsigned j = 0; j < 4; ++j) {
315     auto &flag = optFlags[j].get();
316     if (flag) {
317       optCLIPosition = flag.getPosition();
318       break;
319     }
320   }
321   // Generate vector of pass information, plus the index at which we should
322   // insert any optimization passes in that vector (optPosition).
323   SmallVector<const llvm::PassInfo *, 4> passes;
324   unsigned optPosition = 0;
325   for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
326     passes.push_back(options.llvmPasses[i]);
327     if (optCLIPosition < options.llvmPasses.getPosition(i)) {
328       optPosition = i;
329       optCLIPosition = UINT_MAX; // To ensure we never insert again
330     }
331   }
332 
333   MLIRContext context;
334   registerAllDialects(context.getDialectRegistry());
335 
336   auto m = parseMLIRInput(options.inputFilename, &context);
337   if (!m) {
338     llvm::errs() << "could not parse the input IR\n";
339     return 1;
340   }
341 
342   if (config.mlirTransformer)
343     if (failed(config.mlirTransformer(m.get())))
344       return EXIT_FAILURE;
345 
346   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
347   if (!tmBuilderOrError) {
348     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
349     return EXIT_FAILURE;
350   }
351   auto tmOrError = tmBuilderOrError->createTargetMachine();
352   if (!tmOrError) {
353     llvm::errs() << "Failed to create a TargetMachine for the host\n";
354     return EXIT_FAILURE;
355   }
356 
357   auto transformer = mlir::makeLLVMPassesTransformer(
358       passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
359 
360   CompileAndExecuteConfig compileAndExecuteConfig;
361   compileAndExecuteConfig.transformer = transformer;
362   compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
363   compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
364 
365   // Get the function used to compile and execute the module.
366   using CompileAndExecuteFnT =
367       Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
368   auto compileAndExecuteFn =
369       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
370           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
371           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
372           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
373           .Case("void", compileAndExecuteVoidFunction)
374           .Default(nullptr);
375 
376   Error error = compileAndExecuteFn
377                     ? compileAndExecuteFn(options, m.get(),
378                                           options.mainFuncName.getValue(),
379                                           compileAndExecuteConfig)
380                     : make_string_error("unsupported function type");
381 
382   int exitCode = EXIT_SUCCESS;
383   llvm::handleAllErrors(std::move(error),
384                         [&exitCode](const llvm::ErrorInfoBase &info) {
385                           llvm::errs() << "Error: ";
386                           info.log(llvm::errs());
387                           llvm::errs() << '\n';
388                           exitCode = EXIT_FAILURE;
389                         });
390 
391   return exitCode;
392 }
393