1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/BuiltinTypes.h"
23 #include "mlir/IR/MLIRContext.h"
24 #include "mlir/Parser/Parser.h"
25 #include "mlir/Support/FileUtilities.h"
26 
27 #include "llvm/ADT/STLExtras.h"
28 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
29 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/LegacyPassNameParser.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Support/FileUtilities.h"
35 #include "llvm/Support/SourceMgr.h"
36 #include "llvm/Support/StringSaver.h"
37 #include "llvm/Support/ToolOutputFile.h"
38 #include <cstdint>
39 #include <numeric>
40 #include <utility>
41 
42 using namespace mlir;
43 using llvm::Error;
44 
namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
struct Options {
  /// Positional input file containing the MLIR module to run ("-" by default).
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  /// Symbol name of the entry point to invoke (-e, "main" by default).
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  /// Expected result type of the entry point. JitRunnerMain uses it to pick
  /// how the entry point is invoked; non-void results are printed to stdout.
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  // Category grouping the -On flags below; must be declared before them.
  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI variables for -On options. The selected level drives both the LLVM IR
  // optimizing transformer and the JIT codegen optimization level.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  /// Comma-separated shared libraries. Each one is either initialized through
  /// its mlir-runner init/destroy callbacks or handed to the ExecutionEngine.
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::MiscFlags::CommaSeparated, llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump JITted-compiled object to file specified with "
                     "-object-filename (<input file>.o by default).")};

  // NOTE(review): despite its help text, this option *overrides* the dump
  // file name; "<input file>.o" is only the fallback used when it is empty.
  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};

  /// Hidden flag: print "true"/"false" depending on whether an LLJIT instance
  /// can be created for the host, then exit without running anything.
  llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
                                      llvm::cl::desc("Report host JIT support"),
                                      llvm::cl::Hidden};
};

struct CompileAndExecuteConfig {
  /// LLVM module transformer that is passed to ExecutionEngine.
  std::function<llvm::Error(llvm::Module *)> transformer;

  /// A custom function that is passed to ExecutionEngine. It processes MLIR
  /// module and creates LLVM IR module. Non-owning (function_ref): the
  /// referenced callable must outlive this config.
  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
                                                   llvm::LLVMContext &)>
      llvmModuleBuilder;

  /// A custom function that is passed to ExecutionEngine to register symbols
  /// at runtime. Non-owning (function_ref), like llvmModuleBuilder.
  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
      runtimeSymbolMap;
};

} // namespace
113 
parseMLIRInput(StringRef inputFilename,MLIRContext * context)114 static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
115                                             MLIRContext *context) {
116   // Set up the input file.
117   std::string errorMessage;
118   auto file = openInputFile(inputFilename, &errorMessage);
119   if (!file) {
120     llvm::errs() << errorMessage << "\n";
121     return nullptr;
122   }
123 
124   llvm::SourceMgr sourceMgr;
125   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
126   return parseSourceFile<ModuleOp>(sourceMgr, context);
127 }
128 
makeStringError(const Twine & message)129 static inline Error makeStringError(const Twine &message) {
130   return llvm::make_error<llvm::StringError>(message.str(),
131                                              llvm::inconvertibleErrorCode());
132 }
133 
getCommandLineOptLevel(Options & options)134 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
135   Optional<unsigned> optLevel;
136   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
137       options.optO0, options.optO1, options.optO2, options.optO3};
138 
139   // Determine if there is an optimization flag present.
140   for (unsigned j = 0; j < 4; ++j) {
141     auto &flag = optFlags[j].get();
142     if (flag) {
143       optLevel = j;
144       break;
145     }
146   }
147   return optLevel;
148 }
149 
150 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
compileAndExecute(Options & options,ModuleOp module,StringRef entryPoint,CompileAndExecuteConfig config,void ** args)151 static Error compileAndExecute(Options &options, ModuleOp module,
152                                StringRef entryPoint,
153                                CompileAndExecuteConfig config, void **args) {
154   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
155   if (auto clOptLevel = getCommandLineOptLevel(options))
156     jitCodeGenOptLevel = static_cast<llvm::CodeGenOpt::Level>(*clOptLevel);
157 
158   // If shared library implements custom mlir-runner library init and destroy
159   // functions, we'll use them to register the library with the execution
160   // engine. Otherwise we'll pass library directly to the execution engine.
161   SmallVector<SmallString<256>, 4> libPaths;
162 
163   // Use absolute library path so that gdb can find the symbol table.
164   transform(
165       options.clSharedLibs, std::back_inserter(libPaths),
166       [](std::string libPath) {
167         SmallString<256> absPath(libPath.begin(), libPath.end());
168         cantFail(llvm::errorCodeToError(llvm::sys::fs::make_absolute(absPath)));
169         return absPath;
170       });
171 
172   // Libraries that we'll pass to the ExecutionEngine for loading.
173   SmallVector<StringRef, 4> executionEngineLibs;
174 
175   using MlirRunnerInitFn = void (*)(llvm::StringMap<void *> &);
176   using MlirRunnerDestroyFn = void (*)();
177 
178   llvm::StringMap<void *> exportSymbols;
179   SmallVector<MlirRunnerDestroyFn> destroyFns;
180 
181   // Handle libraries that do support mlir-runner init/destroy callbacks.
182   for (auto &libPath : libPaths) {
183     auto lib = llvm::sys::DynamicLibrary::getPermanentLibrary(libPath.c_str());
184     void *initSym = lib.getAddressOfSymbol("__mlir_runner_init");
185     void *destroySim = lib.getAddressOfSymbol("__mlir_runner_destroy");
186 
187     // Library does not support mlir runner, load it with ExecutionEngine.
188     if (!initSym || !destroySim) {
189       executionEngineLibs.push_back(libPath);
190       continue;
191     }
192 
193     auto initFn = reinterpret_cast<MlirRunnerInitFn>(initSym);
194     initFn(exportSymbols);
195 
196     auto destroyFn = reinterpret_cast<MlirRunnerDestroyFn>(destroySim);
197     destroyFns.push_back(destroyFn);
198   }
199 
200   // Build a runtime symbol map from the config and exported symbols.
201   auto runtimeSymbolMap = [&](llvm::orc::MangleAndInterner interner) {
202     auto symbolMap = config.runtimeSymbolMap ? config.runtimeSymbolMap(interner)
203                                              : llvm::orc::SymbolMap();
204     for (auto &exportSymbol : exportSymbols)
205       symbolMap[interner(exportSymbol.getKey())] =
206           llvm::JITEvaluatedSymbol::fromPointer(exportSymbol.getValue());
207     return symbolMap;
208   };
209 
210   mlir::ExecutionEngineOptions engineOptions;
211   engineOptions.llvmModuleBuilder = config.llvmModuleBuilder;
212   if (config.transformer)
213     engineOptions.transformer = config.transformer;
214   engineOptions.jitCodeGenOptLevel = jitCodeGenOptLevel;
215   engineOptions.sharedLibPaths = executionEngineLibs;
216   engineOptions.enableObjectCache = true;
217   auto expectedEngine = mlir::ExecutionEngine::create(module, engineOptions);
218   if (!expectedEngine)
219     return expectedEngine.takeError();
220 
221   auto engine = std::move(*expectedEngine);
222   engine->registerSymbols(runtimeSymbolMap);
223 
224   auto expectedFPtr = engine->lookupPacked(entryPoint);
225   if (!expectedFPtr)
226     return expectedFPtr.takeError();
227 
228   if (options.dumpObjectFile)
229     engine->dumpToObjectFile(options.objectFilename.empty()
230                                  ? options.inputFilename + ".o"
231                                  : options.objectFilename);
232 
233   void (*fptr)(void **) = *expectedFPtr;
234   (*fptr)(args);
235 
236   // Run all dynamic library destroy callbacks to prepare for the shutdown.
237   llvm::for_each(destroyFns, [](MlirRunnerDestroyFn destroy) { destroy(); });
238 
239   return Error::success();
240 }
241 
compileAndExecuteVoidFunction(Options & options,ModuleOp module,StringRef entryPoint,CompileAndExecuteConfig config)242 static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
243                                            StringRef entryPoint,
244                                            CompileAndExecuteConfig config) {
245   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
246   if (!mainFunction || mainFunction.empty())
247     return makeStringError("entry point not found");
248   void *empty = nullptr;
249   return compileAndExecute(options, module, entryPoint, std::move(config),
250                            &empty);
251 }
252 
253 template <typename Type>
254 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
255 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)256 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
257   auto resultType = mainFunction.getFunctionType()
258                         .cast<LLVM::LLVMFunctionType>()
259                         .getReturnType()
260                         .dyn_cast<IntegerType>();
261   if (!resultType || resultType.getWidth() != 32)
262     return makeStringError("only single i32 function result supported");
263   return Error::success();
264 }
265 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)266 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
267   auto resultType = mainFunction.getFunctionType()
268                         .cast<LLVM::LLVMFunctionType>()
269                         .getReturnType()
270                         .dyn_cast<IntegerType>();
271   if (!resultType || resultType.getWidth() != 64)
272     return makeStringError("only single i64 function result supported");
273   return Error::success();
274 }
275 template <>
checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction)276 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
277   if (!mainFunction.getFunctionType()
278            .cast<LLVM::LLVMFunctionType>()
279            .getReturnType()
280            .isa<Float32Type>())
281     return makeStringError("only single f32 function result supported");
282   return Error::success();
283 }
284 template <typename Type>
compileAndExecuteSingleReturnFunction(Options & options,ModuleOp module,StringRef entryPoint,CompileAndExecuteConfig config)285 Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
286                                             StringRef entryPoint,
287                                             CompileAndExecuteConfig config) {
288   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
289   if (!mainFunction || mainFunction.isExternal())
290     return makeStringError("entry point not found");
291 
292   if (mainFunction.getFunctionType()
293           .cast<LLVM::LLVMFunctionType>()
294           .getNumParams() != 0)
295     return makeStringError("function inputs not supported");
296 
297   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
298     return error;
299 
300   Type res;
301   struct {
302     void *data;
303   } data;
304   data.data = &res;
305   if (auto error = compileAndExecute(options, module, entryPoint,
306                                      std::move(config), (void **)&data))
307     return error;
308 
309   // Intentional printing of the output so we can test.
310   llvm::outs() << res << '\n';
311 
312   return Error::success();
313 }
314 
315 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
316 /// standard C++ main functions.
JitRunnerMain(int argc,char ** argv,const DialectRegistry & registry,JitRunnerConfig config)317 int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
318                         JitRunnerConfig config) {
319   // Create the options struct containing the command line options for the
320   // runner. This must come before the command line options are parsed.
321   Options options;
322   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
323 
324   if (options.hostSupportsJit) {
325     auto J = llvm::orc::LLJITBuilder().create();
326     if (J)
327       llvm::outs() << "true\n";
328     else {
329       llvm::consumeError(J.takeError());
330       llvm::outs() << "false\n";
331     }
332     return 0;
333   }
334 
335   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
336   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
337       options.optO0, options.optO1, options.optO2, options.optO3};
338 
339   MLIRContext context(registry);
340 
341   auto m = parseMLIRInput(options.inputFilename, &context);
342   if (!m) {
343     llvm::errs() << "could not parse the input IR\n";
344     return 1;
345   }
346 
347   if (config.mlirTransformer)
348     if (failed(config.mlirTransformer(m.get())))
349       return EXIT_FAILURE;
350 
351   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
352   if (!tmBuilderOrError) {
353     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
354     return EXIT_FAILURE;
355   }
356   auto tmOrError = tmBuilderOrError->createTargetMachine();
357   if (!tmOrError) {
358     llvm::errs() << "Failed to create a TargetMachine for the host\n";
359     return EXIT_FAILURE;
360   }
361 
362   CompileAndExecuteConfig compileAndExecuteConfig;
363   if (optLevel) {
364     compileAndExecuteConfig.transformer = mlir::makeOptimizingTransformer(
365         *optLevel, /*sizeLevel=*/0, /*targetMachine=*/tmOrError->get());
366   }
367   compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
368   compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
369 
370   // Get the function used to compile and execute the module.
371   using CompileAndExecuteFnT =
372       Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
373   auto compileAndExecuteFn =
374       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
375           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
376           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
377           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
378           .Case("void", compileAndExecuteVoidFunction)
379           .Default(nullptr);
380 
381   Error error = compileAndExecuteFn
382                     ? compileAndExecuteFn(options, m.get(),
383                                           options.mainFuncName.getValue(),
384                                           compileAndExecuteConfig)
385                     : makeStringError("unsupported function type");
386 
387   int exitCode = EXIT_SUCCESS;
388   llvm::handleAllErrors(std::move(error),
389                         [&exitCode](const llvm::ErrorInfoBase &info) {
390                           llvm::errs() << "Error: ";
391                           info.log(llvm::errs());
392                           llvm::errs() << '\n';
393                           exitCode = EXIT_FAILURE;
394                         });
395 
396   return exitCode;
397 }
398