1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/InitAllDialects.h"
26 #include "mlir/Parser.h"
27 #include "mlir/Support/FileUtilities.h"
28 
29 #include "llvm/ADT/STLExtras.h"
30 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
31 #include "llvm/IR/IRBuilder.h"
32 #include "llvm/IR/LLVMContext.h"
33 #include "llvm/IR/LegacyPassNameParser.h"
34 #include "llvm/IR/Module.h"
35 #include "llvm/Support/CommandLine.h"
36 #include "llvm/Support/FileUtilities.h"
37 #include "llvm/Support/SourceMgr.h"
38 #include "llvm/Support/StringSaver.h"
39 #include "llvm/Support/ToolOutputFile.h"
40 #include <cstdint>
41 #include <numeric>
42 
43 using namespace mlir;
44 using llvm::Error;
45 
46 namespace {
47 /// This options struct prevents the need for global static initializers, and
48 /// is only initialized if the JITRunner is invoked.
49 struct Options {
50   llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
51                                            llvm::cl::desc("<input file>"),
52                                            llvm::cl::init("-")};
53   llvm::cl::opt<std::string> mainFuncName{
54       "e", llvm::cl::desc("The function to be called"),
55       llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
56   llvm::cl::opt<std::string> mainFuncType{
57       "entry-point-result",
58       llvm::cl::desc("Textual description of the function type to be called"),
59       llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};
60 
61   llvm::cl::OptionCategory optFlags{"opt-like flags"};
62 
63   // CLI list of pass information
64   llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
65       llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};
66 
67   // CLI variables for -On options.
68   llvm::cl::opt<bool> optO0{"O0",
69                             llvm::cl::desc("Run opt passes and codegen at O0"),
70                             llvm::cl::cat(optFlags)};
71   llvm::cl::opt<bool> optO1{"O1",
72                             llvm::cl::desc("Run opt passes and codegen at O1"),
73                             llvm::cl::cat(optFlags)};
74   llvm::cl::opt<bool> optO2{"O2",
75                             llvm::cl::desc("Run opt passes and codegen at O2"),
76                             llvm::cl::cat(optFlags)};
77   llvm::cl::opt<bool> optO3{"O3",
78                             llvm::cl::desc("Run opt passes and codegen at O3"),
79                             llvm::cl::cat(optFlags)};
80 
81   llvm::cl::OptionCategory clOptionsCategory{"linking options"};
82   llvm::cl::list<std::string> clSharedLibs{
83       "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
84       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
85       llvm::cl::cat(clOptionsCategory)};
86 
87   /// CLI variables for debugging.
88   llvm::cl::opt<bool> dumpObjectFile{
89       "dump-object-file",
90       llvm::cl::desc("Dump JITted-compiled object to file specified with "
91                      "-object-filename (<input file>.o by default).")};
92 
93   llvm::cl::opt<std::string> objectFilename{
94       "object-filename",
95       llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
96 };
97 } // end anonymous namespace
98 
99 static OwningModuleRef parseMLIRInput(StringRef inputFilename,
100                                       MLIRContext *context) {
101   // Set up the input file.
102   std::string errorMessage;
103   auto file = openInputFile(inputFilename, &errorMessage);
104   if (!file) {
105     llvm::errs() << errorMessage << "\n";
106     return nullptr;
107   }
108 
109   llvm::SourceMgr sourceMgr;
110   sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
111   return OwningModuleRef(parseSourceFile(sourceMgr, context));
112 }
113 
114 static inline Error make_string_error(const Twine &message) {
115   return llvm::make_error<llvm::StringError>(message.str(),
116                                              llvm::inconvertibleErrorCode());
117 }
118 
119 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
120   Optional<unsigned> optLevel;
121   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
122       options.optO0, options.optO1, options.optO2, options.optO3};
123 
124   // Determine if there is an optimization flag present.
125   for (unsigned j = 0; j < 4; ++j) {
126     auto &flag = optFlags[j].get();
127     if (flag) {
128       optLevel = j;
129       break;
130     }
131   }
132   return optLevel;
133 }
134 
135 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
136 static Error
137 compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
138                   std::function<llvm::Error(llvm::Module *)> transformer,
139                   void **args) {
140   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
141   if (auto clOptLevel = getCommandLineOptLevel(options))
142     jitCodeGenOptLevel =
143         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
144   SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
145                                  options.clSharedLibs.end());
146   auto expectedEngine = mlir::ExecutionEngine::create(module, transformer,
147                                                       jitCodeGenOptLevel, libs);
148   if (!expectedEngine)
149     return expectedEngine.takeError();
150 
151   auto engine = std::move(*expectedEngine);
152   auto expectedFPtr = engine->lookup(entryPoint);
153   if (!expectedFPtr)
154     return expectedFPtr.takeError();
155 
156   if (options.dumpObjectFile)
157     engine->dumpToObjectFile(options.objectFilename.empty()
158                                  ? options.inputFilename + ".o"
159                                  : options.objectFilename);
160 
161   void (*fptr)(void **) = *expectedFPtr;
162   (*fptr)(args);
163 
164   return Error::success();
165 }
166 
167 static Error compileAndExecuteVoidFunction(
168     Options &options, ModuleOp module, StringRef entryPoint,
169     std::function<llvm::Error(llvm::Module *)> transformer) {
170   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
171   if (!mainFunction || mainFunction.empty())
172     return make_string_error("entry point not found");
173   void *empty = nullptr;
174   return compileAndExecute(options, module, entryPoint, transformer, &empty);
175 }
176 
177 template <typename Type>
178 Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
179 template <>
180 Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
181   if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
182     return make_string_error("only single llvm.i32 function result supported");
183   return Error::success();
184 }
185 template <>
186 Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
187   if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
188     return make_string_error("only single llvm.i64 function result supported");
189   return Error::success();
190 }
191 template <>
192 Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
193   if (!mainFunction.getType().getFunctionResultType().isFloatTy())
194     return make_string_error("only single llvm.f32 function result supported");
195   return Error::success();
196 }
197 template <typename Type>
198 Error compileAndExecuteSingleReturnFunction(
199     Options &options, ModuleOp module, StringRef entryPoint,
200     std::function<llvm::Error(llvm::Module *)> transformer) {
201   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
202   if (!mainFunction || mainFunction.isExternal())
203     return make_string_error("entry point not found");
204 
205   if (mainFunction.getType().getFunctionNumParams() != 0)
206     return make_string_error("function inputs not supported");
207 
208   if (Error error = checkCompatibleReturnType<Type>(mainFunction))
209     return error;
210 
211   Type res;
212   struct {
213     void *data;
214   } data;
215   data.data = &res;
216   if (auto error = compileAndExecute(options, module, entryPoint, transformer,
217                                      (void **)&data))
218     return error;
219 
220   // Intentional printing of the output so we can test.
221   llvm::outs() << res << '\n';
222 
223   return Error::success();
224 }
225 
226 /// Entry point for all CPU runners. Expects the common argc/argv
227 /// arguments for standard C++ main functions and an mlirTransformer.
228 /// The latter is applied after parsing the input into MLIR IR and
229 /// before passing the MLIR module to the ExecutionEngine.
230 int mlir::JitRunnerMain(
231     int argc, char **argv,
232     function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
233   // Create the options struct containing the command line options for the
234   // runner. This must come before the command line options are parsed.
235   Options options;
236   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
237 
238   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
239   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
240       options.optO0, options.optO1, options.optO2, options.optO3};
241   unsigned optCLIPosition = 0;
242   // Determine if there is an optimization flag present, and its CLI position
243   // (optCLIPosition).
244   for (unsigned j = 0; j < 4; ++j) {
245     auto &flag = optFlags[j].get();
246     if (flag) {
247       optCLIPosition = flag.getPosition();
248       break;
249     }
250   }
251   // Generate vector of pass information, plus the index at which we should
252   // insert any optimization passes in that vector (optPosition).
253   SmallVector<const llvm::PassInfo *, 4> passes;
254   unsigned optPosition = 0;
255   for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
256     passes.push_back(options.llvmPasses[i]);
257     if (optCLIPosition < options.llvmPasses.getPosition(i)) {
258       optPosition = i;
259       optCLIPosition = UINT_MAX; // To ensure we never insert again
260     }
261   }
262 
263   MLIRContext context;
264   registerAllDialects(context.getDialectRegistry());
265 
266   auto m = parseMLIRInput(options.inputFilename, &context);
267   if (!m) {
268     llvm::errs() << "could not parse the input IR\n";
269     return 1;
270   }
271 
272   if (mlirTransformer)
273     if (failed(mlirTransformer(m.get())))
274       return EXIT_FAILURE;
275 
276   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
277   if (!tmBuilderOrError) {
278     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
279     return EXIT_FAILURE;
280   }
281   auto tmOrError = tmBuilderOrError->createTargetMachine();
282   if (!tmOrError) {
283     llvm::errs() << "Failed to create a TargetMachine for the host\n";
284     return EXIT_FAILURE;
285   }
286 
287   auto transformer = mlir::makeLLVMPassesTransformer(
288       passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
289 
290   // Get the function used to compile and execute the module.
291   using CompileAndExecuteFnT =
292       Error (*)(Options &, ModuleOp, StringRef,
293                 std::function<llvm::Error(llvm::Module *)>);
294   auto compileAndExecuteFn =
295       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
296           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
297           .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
298           .Case("f32", compileAndExecuteSingleReturnFunction<float>)
299           .Case("void", compileAndExecuteVoidFunction)
300           .Default(nullptr);
301 
302   Error error =
303       compileAndExecuteFn
304           ? compileAndExecuteFn(options, m.get(),
305                                 options.mainFuncName.getValue(), transformer)
306           : make_string_error("unsupported function type");
307 
308   int exitCode = EXIT_SUCCESS;
309   llvm::handleAllErrors(std::move(error),
310                         [&exitCode](const llvm::ErrorInfoBase &info) {
311                           llvm::errs() << "Error: ";
312                           info.log(llvm::errs());
313                           llvm::errs() << '\n';
314                           exitCode = EXIT_FAILURE;
315                         });
316 
317   return exitCode;
318 }
319