//===- jit-runner.cpp - MLIR CPU Execution Driver Library -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is a library that provides a shared implementation for command line
// utilities that execute an MLIR file on the CPU by translating MLIR to LLVM
// IR before JIT-compiling and executing the latter.
//
// The translation can be customized by providing an MLIR to MLIR
// transformation.
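//
// As an illustrative sketch (not part of this library), a driver tool built
// on top of it typically initializes the native target and forwards its
// command line to JitRunnerMain, optionally passing an MLIR-to-MLIR
// transformation:
//
//   int main(int argc, char **argv) {
//     llvm::InitializeNativeTarget();
//     llvm::InitializeNativeTargetAsmPrinter();
//     return mlir::JitRunnerMain(argc, argv, /*mlirTransformer=*/nullptr);
//   }
//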
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/JitRunner.h"

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/ExecutionEngine/ExecutionEngine.h"
#include "mlir/ExecutionEngine/OptUtils.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Parser.h"
#include "mlir/Support/FileUtilities.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/LegacyPassNameParser.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/StringSaver.h"
#include "llvm/Support/ToolOutputFile.h"
#include <climits>
#include <cstdint>
#include <numeric>

using namespace mlir;
using llvm::Error;

namespace {
/// This options struct prevents the need for global static initializers, and
/// is only initialized if the JITRunner is invoked.
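///
/// As an illustrative example (the exact tool name and flag values are just
/// for demonstration), a runner built on this library might be invoked as:
///
///   mlir-cpu-runner input.mlir -e main -entry-point-result=f32 -O2 \
///       -shared-libs=libmlir_runner_utils.so -dump-object-file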
struct Options {
  llvm::cl::opt<std::string> inputFilename{llvm::cl::Positional,
                                           llvm::cl::desc("<input file>"),
                                           llvm::cl::init("-")};
  llvm::cl::opt<std::string> mainFuncName{
      "e", llvm::cl::desc("The function to be called"),
      llvm::cl::value_desc("<function name>"), llvm::cl::init("main")};
  llvm::cl::opt<std::string> mainFuncType{
      "entry-point-result",
      llvm::cl::desc("Textual description of the function type to be called"),
      llvm::cl::value_desc("f32 | i32 | i64 | void"), llvm::cl::init("f32")};

  llvm::cl::OptionCategory optFlags{"opt-like flags"};

  // CLI list of pass information
  llvm::cl::list<const llvm::PassInfo *, bool, llvm::PassNameParser> llvmPasses{
      llvm::cl::desc("LLVM optimizing passes to run"), llvm::cl::cat(optFlags)};

  // CLI variables for -On options.
  llvm::cl::opt<bool> optO0{"O0",
                            llvm::cl::desc("Run opt passes and codegen at O0"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO1{"O1",
                            llvm::cl::desc("Run opt passes and codegen at O1"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO2{"O2",
                            llvm::cl::desc("Run opt passes and codegen at O2"),
                            llvm::cl::cat(optFlags)};
  llvm::cl::opt<bool> optO3{"O3",
                            llvm::cl::desc("Run opt passes and codegen at O3"),
                            llvm::cl::cat(optFlags)};

  llvm::cl::OptionCategory clOptionsCategory{"linking options"};
  llvm::cl::list<std::string> clSharedLibs{
      "shared-libs", llvm::cl::desc("Libraries to link dynamically"),
      llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
      llvm::cl::cat(clOptionsCategory)};

  /// CLI variables for debugging.
  llvm::cl::opt<bool> dumpObjectFile{
      "dump-object-file",
      llvm::cl::desc("Dump JIT-compiled object to the file specified with "
                     "-object-filename (<input file>.o by default).")};

  llvm::cl::opt<std::string> objectFilename{
      "object-filename",
      llvm::cl::desc("Name of the file to dump the JIT-compiled object to "
                     "(<input file>.o by default)")};
};
} // end anonymous namespace

static OwningModuleRef parseMLIRInput(StringRef inputFilename,
                                      MLIRContext *context) {
  // Set up the input file.
  std::string errorMessage;
  auto file = openInputFile(inputFilename, &errorMessage);
  if (!file) {
    llvm::errs() << errorMessage << "\n";
    return nullptr;
  }

  llvm::SourceMgr sourceMgr;
  sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
  return OwningModuleRef(parseSourceFile(sourceMgr, context));
}

static inline Error make_string_error(const Twine &message) {
  return llvm::make_error<llvm::StringError>(message.str(),
                                             llvm::inconvertibleErrorCode());
}

static Optional<unsigned> getCommandLineOptLevel(Options &options) {
  Optional<unsigned> optLevel;
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};

  // Determine if there is an optimization flag present.
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optLevel = j;
      break;
    }
  }
  return optLevel;
}

// JIT-compile the given module and run "entryPoint" with "args" as arguments.
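// The function is looked up through ExecutionEngine::lookup, which returns a
// packed wrapper with signature void(void **): "args" is expected to be an
// array of pointers, one per argument and result of the entry point (see the
// single-element packing performed by the callers below).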
static Error
compileAndExecute(Options &options, ModuleOp module, StringRef entryPoint,
                  std::function<llvm::Error(llvm::Module *)> transformer,
                  void **args) {
  Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
  if (auto clOptLevel = getCommandLineOptLevel(options))
    jitCodeGenOptLevel =
        static_cast<llvm::CodeGenOpt::Level>(clOptLevel.getValue());
  SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
                                 options.clSharedLibs.end());
  auto expectedEngine = mlir::ExecutionEngine::create(module, transformer,
                                                      jitCodeGenOptLevel, libs);
  if (!expectedEngine)
    return expectedEngine.takeError();

  auto engine = std::move(*expectedEngine);
  auto expectedFPtr = engine->lookup(entryPoint);
  if (!expectedFPtr)
    return expectedFPtr.takeError();

  if (options.dumpObjectFile)
    engine->dumpToObjectFile(options.objectFilename.empty()
                                 ? options.inputFilename + ".o"
                                 : options.objectFilename);

  void (*fptr)(void **) = *expectedFPtr;
  (*fptr)(args);

  return Error::success();
}

static Error compileAndExecuteVoidFunction(
    Options &options, ModuleOp module, StringRef entryPoint,
    std::function<llvm::Error(llvm::Module *)> transformer) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.empty())
    return make_string_error("entry point not found");
  void *empty = nullptr;
  return compileAndExecute(options, module, entryPoint, transformer, &empty);
}

template <typename Type>
Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction);
template <>
Error checkCompatibleReturnType<int32_t>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(32))
    return make_string_error("only single llvm.i32 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<int64_t>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getType().getFunctionResultType().isIntegerTy(64))
    return make_string_error("only single llvm.i64 function result supported");
  return Error::success();
}
template <>
Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
  if (!mainFunction.getType().getFunctionResultType().isFloatTy())
    return make_string_error("only single llvm.f32 function result supported");
  return Error::success();
}
template <typename Type>
Error compileAndExecuteSingleReturnFunction(
    Options &options, ModuleOp module, StringRef entryPoint,
    std::function<llvm::Error(llvm::Module *)> transformer) {
  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
  if (!mainFunction || mainFunction.isExternal())
    return make_string_error("entry point not found");

  if (mainFunction.getType().getFunctionNumParams() != 0)
    return make_string_error("function inputs not supported");

  if (Error error = checkCompatibleReturnType<Type>(mainFunction))
    return error;

  Type res;
  struct {
    void *data;
  } data;
  data.data = &res;
  if (auto error = compileAndExecute(options, module, entryPoint, transformer,
                                     (void **)&data))
    return error;

  // Intentionally print the result so that tests can check the output.
  llvm::outs() << res << '\n';

  return Error::success();
}

/// Entry point for all CPU runners. Expects the common argc/argv
/// arguments for standard C++ main functions and an mlirTransformer.
/// The latter is applied after parsing the input into MLIR IR and
/// before passing the MLIR module to the ExecutionEngine.
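///
/// As an illustrative sketch, a transformer lowering the module to the LLVM
/// dialect could be written as:
///
///   [](mlir::ModuleOp module) {
///     mlir::PassManager pm(module.getContext());
///     pm.addPass(mlir::createLowerToLLVMPass());
///     return pm.run(module);
///   }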
int mlir::JitRunnerMain(
    int argc, char **argv,
    function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer) {
  // Create the options struct containing the command line options for the
  // runner. This must come before the command line options are parsed.
  Options options;
  llvm::cl::ParseCommandLineOptions(argc, argv, "MLIR CPU execution driver\n");

  Optional<unsigned> optLevel = getCommandLineOptLevel(options);
  SmallVector<std::reference_wrapper<llvm::cl::opt<bool>>, 4> optFlags{
      options.optO0, options.optO1, options.optO2, options.optO3};
  unsigned optCLIPosition = 0;
  // Determine if there is an optimization flag present, and its CLI position
  // (optCLIPosition).
  for (unsigned j = 0; j < 4; ++j) {
    auto &flag = optFlags[j].get();
    if (flag) {
      optCLIPosition = flag.getPosition();
      break;
    }
  }
  // Generate vector of pass information, plus the index at which we should
  // insert any optimization passes in that vector (optPosition).
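  // For example, with "-inline -O2 -loop-unroll" the -O2 pipeline is recorded
  // for insertion between the two explicitly requested passes.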
  SmallVector<const llvm::PassInfo *, 4> passes;
  unsigned optPosition = 0;
  for (unsigned i = 0, e = options.llvmPasses.size(); i < e; ++i) {
    passes.push_back(options.llvmPasses[i]);
    if (optCLIPosition < options.llvmPasses.getPosition(i)) {
      optPosition = i;
      optCLIPosition = UINT_MAX; // To ensure we never insert again
    }
  }

  MLIRContext context(/*loadAllDialects=*/false);
  registerAllDialects(&context);

  auto m = parseMLIRInput(options.inputFilename, &context);
  if (!m) {
    llvm::errs() << "could not parse the input IR\n";
    return 1;
  }

  if (mlirTransformer)
    if (failed(mlirTransformer(m.get())))
      return EXIT_FAILURE;

  auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
  if (!tmBuilderOrError) {
    llvm::errs() << "Failed to create a JITTargetMachineBuilder for the host\n";
    return EXIT_FAILURE;
  }
  auto tmOrError = tmBuilderOrError->createTargetMachine();
  if (!tmOrError) {
    llvm::errs() << "Failed to create a TargetMachine for the host\n";
    return EXIT_FAILURE;
  }

  auto transformer = mlir::makeLLVMPassesTransformer(
      passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);

  // Get the function used to compile and execute the module.
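  // Dispatch on the -entry-point-result string; supporting an additional
  // result type requires a matching checkCompatibleReturnType specialization
  // and a corresponding case below.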
  using CompileAndExecuteFnT =
      Error (*)(Options &, ModuleOp, StringRef,
                std::function<llvm::Error(llvm::Module *)>);
  auto compileAndExecuteFn =
      llvm::StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
          .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
          .Case("i64", compileAndExecuteSingleReturnFunction<int64_t>)
          .Case("f32", compileAndExecuteSingleReturnFunction<float>)
          .Case("void", compileAndExecuteVoidFunction)
          .Default(nullptr);

  Error error =
      compileAndExecuteFn
          ? compileAndExecuteFn(options, m.get(),
                                options.mainFuncName.getValue(), transformer)
          : make_string_error("unsupported function type");

  int exitCode = EXIT_SUCCESS;
  llvm::handleAllErrors(std::move(error),
                        [&exitCode](const llvm::ErrorInfoBase &info) {
                          llvm::errs() << "Error: ";
                          info.log(llvm::errs());
                          llvm::errs() << '\n';
                          exitCode = EXIT_FAILURE;
                        });

  return exitCode;
}