1 //===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a library that provides a shared implementation for command line
10 // utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
11 // IR before JIT-compiling and executing the latter.
12 //
13 // The translation can be customized by providing an MLIR to MLIR
14 // transformation.
15 //===----------------------------------------------------------------------===//
16 
17 #include "mlir/ExecutionEngine/JitRunner.h"
18 
19 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
20 #include "mlir/ExecutionEngine/ExecutionEngine.h"
21 #include "mlir/ExecutionEngine/OptUtils.h"
22 #include "mlir/IR/MLIRContext.h"
23 #include "mlir/IR/Module.h"
24 #include "mlir/IR/StandardTypes.h"
25 #include "mlir/Parser.h"
26 #include "mlir/Support/FileUtilities.h"
27 
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
30 #include "llvm/IR/IRBuilder.h"
31 #include "llvm/IR/LLVMContext.h"
32 #include "llvm/IR/LegacyPassNameParser.h"
33 #include "llvm/IR/Module.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Support/FileUtilities.h"
36 #include "llvm/Support/SourceMgr.h"
37 #include "llvm/Support/StringSaver.h"
38 #include "llvm/Support/ToolOutputFile.h"
39 #include <numeric>
40 
41 using namespace mlir;
42 using llvm::Error;
43 
44 namespace {
45 /// This options struct prevents the need for global static initializers, and
46 /// is only initialized if the JITRunner is invoked.
47 struct Options {
48   llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
49                                            llvm::cl::desc("<input file>"),
50                                            llvm::cl::init("-")};
51   llvm::cl::opt<std::string> mainFuncName{
52       "e", llvm::cl::desc("The function to be called"),
53       llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
54   llvm::cl::opt<std::string> mainFuncType{
55       "entry-point-result",
56       llvm::cl::desc("Textual description of the function type to be called"),
57       llvm::cl::value_desc("f32 | void"), llvm::cl::init("f32")};
58 
59   llvm::cl::OptionCategory optFlags{"opt-like flags"};
60 
61   // CLI list of pass information
62   llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
63       llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};
64 
65   // CLI variables for -On options.
66   llvm::cl::opt<bool> optO0{"O0",
67                             llvm::cl::desc("Run opt passes and codegen at O0"),
68                             llvm::cl::cat(optFlags)};
69   llvm::cl::opt<bool> optO1{"O1",
70                             llvm::cl::desc("Run opt passes and codegen at O1"),
71                             llvm::cl::cat(optFlags)};
72   llvm::cl::opt<bool> optO2{"O2",
73                             llvm::cl::desc("Run opt passes and codegen at O2"),
74                             llvm::cl::cat(optFlags)};
75   llvm::cl::opt<bool> optO3{"O3",
76                             llvm::cl::desc("Run opt passes and codegen at O3"),
77                             llvm::cl::cat(optFlags)};
78 
79   llvm::cl::OptionCategory clOptionsCategory{"linking options"};
80   llvm::cl::list<std::string> clSharedLibs{
81       "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
82       llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
83       llvm::cl::cat(clOptionsCategory)};
84 
85   /// CLI variables for debugging.
86   llvm::cl::opt<bool> dumpObjectFile{
87       "dump-object-file",
88       llvm::cl::desc("Dump JITted-compiled object to file specified with "
89                      "-object-filename (<input file>.o by default).")};
90 
91   llvm::cl::opt<std::string> objectFilename{
92       "object-filename",
93       llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
94 };
95 } // end anonymous namespace
96 
97 static OwningModuleRef parseMLIRInput(StringRef inputFilename,
98                                       MLIRContext *context) {
99   // Set up the input file.
100   std::string errorMessage;
101   auto file = openInputFile(inputFilename, &errorMessage);
102   if (!file) {
103     llvm::errs() << errorMessage << "\n";
104     return nullptr;
105   }
106 
107   llvm::SourceMgr sourceMgr;
108   sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
109   return OwningModuleRef(parseSourceFile(sourceMgr, context));
110 }
111 
112 static inline Error make_string_error(const Twine &message) {
113   return llvm::make_error<llvm::StringError>(message.str(),
114                                              llvm::inconvertibleErrorCode());
115 }
116 
117 static Optional<unsigned> getCommandLineOptLevel(Options &options) {
118   Optional<unsigned> optLevel;
119   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
120       options.optO0, options.optO1, options.optO2, options.optO3};
121 
122   // Determine if there is an optimization flag present.
123   for (unsigned j = 0; j < 4; ++j) {
124     auto &flag = optFlags[j].get();
125     if (flag) {
126       optLevel = j;
127       break;
128     }
129   }
130   return optLevel;
131 }
132 
133 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
134 static Error
135 compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
136                   std::function<llvm::Error(llvm::Module *)> transformer,
137                   void **args) {
138   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
139   if (auto clOptLevel = getCommandLineOptLevel(options))
140     jitCodeGenOptLevel =
141         static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
142   SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
143                                  options.clSharedLibs.end());
144   auto expectedEngine = mlir::ExecutionEngine::create(module, transformer,
145                                                       jitCodeGenOptLevel, libs);
146   if (!expectedEngine)
147     return expectedEngine.takeError();
148 
149   auto engine = std::move(*expectedEngine);
150   auto expectedFPtr = engine->lookup(entryPoint);
151   if (!expectedFPtr)
152     return expectedFPtr.takeError();
153 
154   if (options.dumpObjectFile)
155     engine->dumpToObjectFile(options.objectFilename.empty()
156                                  ? options.inputFilename + ".o"
157                                  : options.objectFilename);
158 
159   void (*fptr)(void **) = *expectedFPtr;
160   (*fptr)(args);
161 
162   return Error::success();
163 }
164 
165 static Error compileAndExecuteVoidFunction(
166     Options &options, ModuleOp module, StringRef entryPoint,
167     std::function<llvm::Error(llvm::Module *)> transformer) {
168   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
169   if (!mainFunction || mainFunction.getBlocks().empty())
170     return make_string_error("entry point not found");
171   void *empty = nullptr;
172   return compileAndExecute(options, module, entryPoint, transformer, &empty);
173 }
174 
175 static Error compileAndExecuteSingleFloatReturnFunction(
176     Options &options, ModuleOp module, StringRef entryPoint,
177     std::function<llvm::Error(llvm::Module *)> transformer) {
178   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
179   if (!mainFunction || mainFunction.isExternal())
180     return make_string_error("entry point not found");
181 
182   if (mainFunction.getType().getFunctionNumParams() != 0)
183     return make_string_error("function inputs not supported");
184 
185   if (!mainFunction.getType().getFunctionResultType().isFloatTy())
186     return make_string_error("only single llvm.f32 function result supported");
187 
188   float res;
189   struct {
190     void *data;
191   } data;
192   data.data = &res;
193   if (auto error = compileAndExecute(options, module, entryPoint, transformer,
194                                      (void **)&data))
195     return error;
196 
197   // Intentional printing of the output so we can test.
198   llvm::outs() << res << '\n';
199   return Error::success();
200 }
201 
202 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
203 /// standard C++ main functions and an mlirTransformer.
204 /// The latter is applied after parsing the input into MLIR IR and before
205 /// passing the MLIR module to the ExecutionEngine.
206 int mlir::JitRunnerMain(
207     int argc, char **argv,
208     function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
209   // Create the options struct containing the command line options for the
210   // runner. This must come before the command line options are parsed.
211   Options options;
212   llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");
213 
214   Optional<unsigned> optLevel = getCommandLineOptLevel(options);
215   SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
216       options.optO0, options.optO1, options.optO2, options.optO3};
217   unsigned optCLIPosition = 0;
218   // Determine if there is an optimization flag present, and its CLI position
219   // (optCLIPosition).
220   for (unsigned j = 0; j < 4; ++j) {
221     auto &flag = optFlags[j].get();
222     if (flag) {
223       optCLIPosition = flag.getPosition();
224       break;
225     }
226   }
227   // Generate vector of pass information, plus the index at which we should
228   // insert any optimization passes in that vector (optPosition).
229   SmallVector<const llvm::PassInfo *, 4> passes;
230   unsigned optPosition = 0;
231   for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
232     passes.push_back(options.llvmPasses[i]);
233     if (optCLIPosition < options.llvmPasses.getPosition(i)) {
234       optPosition = i;
235       optCLIPosition = UINT_MAX; // To ensure we never insert again
236     }
237   }
238 
239   MLIRContext context;
240   auto m = parseMLIRInput(options.inputFilename, &context);
241   if (!m) {
242     llvm::errs() << "could not parse the input IR\n";
243     return 1;
244   }
245 
246   if (mlirTransformer)
247     if (failed(mlirTransformer(m.get())))
248       return EXIT_FAILURE;
249 
250   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
251   if (!tmBuilderOrError) {
252     llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
253     return EXIT_FAILURE;
254   }
255   auto tmOrError = tmBuilderOrError->createTargetMachine();
256   if (!tmOrError) {
257     llvm::errs() << "Failed to create a TargetMachine for the host\n";
258     return EXIT_FAILURE;
259   }
260 
261   auto transformer = mlir::makeLLVMPassesTransformer(
262       passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
263 
264   // Get the function used to compile and execute the module.
265   using CompileAndExecuteFnT =
266       Error (*)(Options &, ModuleOp, StringRef,
267                 std::function<llvm::Error(llvm::Module *)>);
268   auto compileAndExecuteFn =
269       llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
270           .Case("f32", compileAndExecuteSingleFloatReturnFunction)
271           .Case("void", compileAndExecuteVoidFunction)
272           .Default(nullptr);
273 
274   Error error =
275       compileAndExecuteFn
276           ? compileAndExecuteFn(options, m.get(),
277                                 options.mainFuncName.getValue(), transformer)
278           : make_string_error("unsupported function type");
279 
280   int exitCode = EXIT_SUCCESS;
281   llvm::handleAllErrors(std::move(error),
282                         [&exitCode](const llvm::ErrorInfoBase &info) {
283                           llvm::errs() << "Error: ";
284                           info.log(llvm::errs());
285                           llvm::errs() << '\n';
286                           exitCode = EXIT_FAILURE;
287                         });
288 
289   return exitCode;
290 }
291