1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a utility that runs an optimization pass and prints the result back
10 // out. It is designed to support unit testing.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
15 #include "mlir/IR/AsmState.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser/Parser.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Support/DebugCounter.h"
26 #include "mlir/Support/FileUtilities.h"
27 #include "mlir/Support/Timing.h"
28 #include "mlir/Support/ToolUtilities.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/FileUtilities.h"
31 #include "llvm/Support/InitLLVM.h"
32 #include "llvm/Support/Regex.h"
33 #include "llvm/Support/SourceMgr.h"
34 #include "llvm/Support/StringSaver.h"
35 #include "llvm/Support/ThreadPool.h"
36 #include "llvm/Support/ToolOutputFile.h"
37 
38 using namespace mlir;
39 using namespace llvm;
40 
41 /// Perform the actions on the input file indicated by the command line flags
42 /// within the specified context.
43 ///
44 /// This typically parses the main source file, runs zero or more optimization
45 /// passes, then prints the output.
46 ///
performActions(raw_ostream & os,bool verifyDiagnostics,bool verifyPasses,SourceMgr & sourceMgr,MLIRContext * context,PassPipelineFn passManagerSetupFn)47 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
48                                     bool verifyPasses, SourceMgr &sourceMgr,
49                                     MLIRContext *context,
50                                     PassPipelineFn passManagerSetupFn) {
51   DefaultTimingManager tm;
52   applyDefaultTimingManagerCLOptions(tm);
53   TimingScope timing = tm.getRootScope();
54 
55   // Disable multi-threading when parsing the input file. This removes the
56   // unnecessary/costly context synchronization when parsing.
57   bool wasThreadingEnabled = context->isMultithreadingEnabled();
58   context->disableMultithreading();
59 
60   // Prepare the pass manager and apply any command line options.
61   PassManager pm(context, OpPassManager::Nesting::Implicit);
62   pm.enableVerifier(verifyPasses);
63   applyPassManagerCLOptions(pm);
64   pm.enableTiming(timing);
65 
66   // Prepare the parser config, and attach any useful/necessary resource
67   // handlers.
68   ParserConfig config(context);
69   attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
70 
71   // Parse the input file and reset the context threading state.
72   TimingScope parserTiming = timing.nest("Parser");
73   OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config));
74   context->enableMultithreading(wasThreadingEnabled);
75   if (!module)
76     return failure();
77   parserTiming.stop();
78 
79   // Callback to build the pipeline.
80   if (failed(passManagerSetupFn(pm)))
81     return failure();
82 
83   // Run the pipeline.
84   if (failed(pm.run(*module)))
85     return failure();
86 
87   // Print the output.
88   TimingScope outputTiming = timing.nest("Output");
89   module->print(os);
90   os << '\n';
91   return success();
92 }
93 
94 /// Parses the memory buffer.  If successfully, run a series of passes against
95 /// it and print the result.
96 static LogicalResult
processBuffer(raw_ostream & os,std::unique_ptr<MemoryBuffer> ownedBuffer,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext,PassPipelineFn passManagerSetupFn,DialectRegistry & registry,llvm::ThreadPool * threadPool)97 processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
98               bool verifyDiagnostics, bool verifyPasses,
99               bool allowUnregisteredDialects, bool preloadDialectsInContext,
100               PassPipelineFn passManagerSetupFn, DialectRegistry &registry,
101               llvm::ThreadPool *threadPool) {
102   // Tell sourceMgr about this buffer, which is what the parser will pick up.
103   SourceMgr sourceMgr;
104   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
105 
106   // Create a context just for the current buffer. Disable threading on creation
107   // since we'll inject the thread-pool separately.
108   MLIRContext context(registry, MLIRContext::Threading::DISABLED);
109   if (threadPool)
110     context.setThreadPool(*threadPool);
111 
112   // Parse the input file.
113   if (preloadDialectsInContext)
114     context.loadAllAvailableDialects();
115   context.allowUnregisteredDialects(allowUnregisteredDialects);
116   if (verifyDiagnostics)
117     context.printOpOnDiagnostic(false);
118   context.getDebugActionManager().registerActionHandler<DebugCounter>();
119 
120   // If we are in verify diagnostics mode then we have a lot of work to do,
121   // otherwise just perform the actions without worrying about it.
122   if (!verifyDiagnostics) {
123     SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
124     return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
125                           &context, passManagerSetupFn);
126   }
127 
128   SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
129 
130   // Do any processing requested by command line flags.  We don't care whether
131   // these actions succeed or fail, we only care what diagnostics they produce
132   // and whether they match our expectations.
133   (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
134                        passManagerSetupFn);
135 
136   // Verify the diagnostic handler to make sure that each of the diagnostics
137   // matched.
138   return sourceMgrHandler.verify();
139 }
140 
MlirOptMain(raw_ostream & outputStream,std::unique_ptr<MemoryBuffer> buffer,PassPipelineFn passManagerSetupFn,DialectRegistry & registry,bool splitInputFile,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext)141 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
142                                 std::unique_ptr<MemoryBuffer> buffer,
143                                 PassPipelineFn passManagerSetupFn,
144                                 DialectRegistry &registry, bool splitInputFile,
145                                 bool verifyDiagnostics, bool verifyPasses,
146                                 bool allowUnregisteredDialects,
147                                 bool preloadDialectsInContext) {
148   // The split-input-file mode is a very specific mode that slices the file
149   // up into small pieces and checks each independently.
150   // We use an explicit threadpool to avoid creating and joining/destroying
151   // threads for each of the split.
152   ThreadPool *threadPool = nullptr;
153 
154   // Create a temporary context for the sake of checking if
155   // --mlir-disable-threading was passed on the command line.
156   // We use the thread-pool this context is creating, and avoid
157   // creating any thread when disabled.
158   MLIRContext threadPoolCtx;
159   if (threadPoolCtx.isMultithreadingEnabled())
160     threadPool = &threadPoolCtx.getThreadPool();
161 
162   auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
163                      raw_ostream &os) {
164     return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
165                          verifyPasses, allowUnregisteredDialects,
166                          preloadDialectsInContext, passManagerSetupFn, registry,
167                          threadPool);
168   };
169   return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
170                                splitInputFile, /*insertMarkerInOutput=*/true);
171 }
172 
MlirOptMain(raw_ostream & outputStream,std::unique_ptr<MemoryBuffer> buffer,const PassPipelineCLParser & passPipeline,DialectRegistry & registry,bool splitInputFile,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext)173 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
174                                 std::unique_ptr<MemoryBuffer> buffer,
175                                 const PassPipelineCLParser &passPipeline,
176                                 DialectRegistry &registry, bool splitInputFile,
177                                 bool verifyDiagnostics, bool verifyPasses,
178                                 bool allowUnregisteredDialects,
179                                 bool preloadDialectsInContext) {
180   auto passManagerSetupFn = [&](PassManager &pm) {
181     auto errorHandler = [&](const Twine &msg) {
182       emitError(UnknownLoc::get(pm.getContext())) << msg;
183       return failure();
184     };
185     return passPipeline.addToPipeline(pm, errorHandler);
186   };
187   return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
188                      registry, splitInputFile, verifyDiagnostics, verifyPasses,
189                      allowUnregisteredDialects, preloadDialectsInContext);
190 }
191 
MlirOptMain(int argc,char ** argv,llvm::StringRef toolName,DialectRegistry & registry,bool preloadDialectsInContext)192 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
193                                 DialectRegistry &registry,
194                                 bool preloadDialectsInContext) {
195   static cl::opt<std::string> inputFilename(
196       cl::Positional, cl::desc("<input file>"), cl::init("-"));
197 
198   static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
199                                              cl::value_desc("filename"),
200                                              cl::init("-"));
201 
202   static cl::opt<bool> splitInputFile(
203       "split-input-file",
204       cl::desc("Split the input file into pieces and process each "
205                "chunk independently"),
206       cl::init(false));
207 
208   static cl::opt<bool> verifyDiagnostics(
209       "verify-diagnostics",
210       cl::desc("Check that emitted diagnostics match "
211                "expected-* lines on the corresponding line"),
212       cl::init(false));
213 
214   static cl::opt<bool> verifyPasses(
215       "verify-each",
216       cl::desc("Run the verifier after each transformation pass"),
217       cl::init(true));
218 
219   static cl::opt<bool> allowUnregisteredDialects(
220       "allow-unregistered-dialect",
221       cl::desc("Allow operation with no registered dialects"), cl::init(false));
222 
223   static cl::opt<bool> showDialects(
224       "show-dialects", cl::desc("Print the list of registered dialects"),
225       cl::init(false));
226 
227   InitLLVM y(argc, argv);
228 
229   // Register any command line options.
230   registerAsmPrinterCLOptions();
231   registerMLIRContextCLOptions();
232   registerPassManagerCLOptions();
233   registerDefaultTimingManagerCLOptions();
234   DebugCounter::registerCLOptions();
235   PassPipelineCLParser passPipeline("", "Compiler passes to run");
236 
237   // Build the list of dialects as a header for the --help message.
238   std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
239   {
240     llvm::raw_string_ostream os(helpHeader);
241     interleaveComma(registry.getDialectNames(), os,
242                     [&](auto name) { os << name; });
243   }
244   // Parse pass names in main to ensure static initialization completed.
245   cl::ParseCommandLineOptions(argc, argv, helpHeader);
246 
247   if (showDialects) {
248     llvm::outs() << "Available Dialects:\n";
249     interleave(
250         registry.getDialectNames(), llvm::outs(),
251         [](auto name) { llvm::outs() << name; }, "\n");
252     return success();
253   }
254 
255   // Set up the input file.
256   std::string errorMessage;
257   auto file = openInputFile(inputFilename, &errorMessage);
258   if (!file) {
259     llvm::errs() << errorMessage << "\n";
260     return failure();
261   }
262 
263   auto output = openOutputFile(outputFilename, &errorMessage);
264   if (!output) {
265     llvm::errs() << errorMessage << "\n";
266     return failure();
267   }
268 
269   if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
270                          splitInputFile, verifyDiagnostics, verifyPasses,
271                          allowUnregisteredDialects, preloadDialectsInContext)))
272     return failure();
273 
274   // Keep the output file if the invocation of MlirOptMain was successful.
275   output->keep();
276   return success();
277 }
278