1 //===- MlirOptMain.cpp - MLIR Optimizer Driver ----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This is a utility that runs an optimization pass and prints the result back
10 // out. It is designed to support unit testing.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Tools/mlir-opt/MlirOptMain.h"
15 #include "mlir/IR/AsmState.h"
16 #include "mlir/IR/Attributes.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/Diagnostics.h"
19 #include "mlir/IR/Dialect.h"
20 #include "mlir/IR/Location.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/Parser/Parser.h"
23 #include "mlir/Pass/Pass.h"
24 #include "mlir/Pass/PassManager.h"
25 #include "mlir/Support/DebugCounter.h"
26 #include "mlir/Support/FileUtilities.h"
27 #include "mlir/Support/Timing.h"
28 #include "mlir/Support/ToolUtilities.h"
29 #include "llvm/Support/CommandLine.h"
30 #include "llvm/Support/FileUtilities.h"
31 #include "llvm/Support/InitLLVM.h"
32 #include "llvm/Support/Regex.h"
33 #include "llvm/Support/SourceMgr.h"
34 #include "llvm/Support/StringSaver.h"
35 #include "llvm/Support/ThreadPool.h"
36 #include "llvm/Support/ToolOutputFile.h"
37
38 using namespace mlir;
39 using namespace llvm;
40
41 /// Perform the actions on the input file indicated by the command line flags
42 /// within the specified context.
43 ///
44 /// This typically parses the main source file, runs zero or more optimization
45 /// passes, then prints the output.
46 ///
performActions(raw_ostream & os,bool verifyDiagnostics,bool verifyPasses,SourceMgr & sourceMgr,MLIRContext * context,PassPipelineFn passManagerSetupFn)47 static LogicalResult performActions(raw_ostream &os, bool verifyDiagnostics,
48 bool verifyPasses, SourceMgr &sourceMgr,
49 MLIRContext *context,
50 PassPipelineFn passManagerSetupFn) {
51 DefaultTimingManager tm;
52 applyDefaultTimingManagerCLOptions(tm);
53 TimingScope timing = tm.getRootScope();
54
55 // Disable multi-threading when parsing the input file. This removes the
56 // unnecessary/costly context synchronization when parsing.
57 bool wasThreadingEnabled = context->isMultithreadingEnabled();
58 context->disableMultithreading();
59
60 // Prepare the pass manager and apply any command line options.
61 PassManager pm(context, OpPassManager::Nesting::Implicit);
62 pm.enableVerifier(verifyPasses);
63 applyPassManagerCLOptions(pm);
64 pm.enableTiming(timing);
65
66 // Prepare the parser config, and attach any useful/necessary resource
67 // handlers.
68 ParserConfig config(context);
69 attachPassReproducerAsmResource(config, pm, wasThreadingEnabled);
70
71 // Parse the input file and reset the context threading state.
72 TimingScope parserTiming = timing.nest("Parser");
73 OwningOpRef<ModuleOp> module(parseSourceFile<ModuleOp>(sourceMgr, config));
74 context->enableMultithreading(wasThreadingEnabled);
75 if (!module)
76 return failure();
77 parserTiming.stop();
78
79 // Callback to build the pipeline.
80 if (failed(passManagerSetupFn(pm)))
81 return failure();
82
83 // Run the pipeline.
84 if (failed(pm.run(*module)))
85 return failure();
86
87 // Print the output.
88 TimingScope outputTiming = timing.nest("Output");
89 module->print(os);
90 os << '\n';
91 return success();
92 }
93
94 /// Parses the memory buffer. If successfully, run a series of passes against
95 /// it and print the result.
96 static LogicalResult
processBuffer(raw_ostream & os,std::unique_ptr<MemoryBuffer> ownedBuffer,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext,PassPipelineFn passManagerSetupFn,DialectRegistry & registry,llvm::ThreadPool * threadPool)97 processBuffer(raw_ostream &os, std::unique_ptr<MemoryBuffer> ownedBuffer,
98 bool verifyDiagnostics, bool verifyPasses,
99 bool allowUnregisteredDialects, bool preloadDialectsInContext,
100 PassPipelineFn passManagerSetupFn, DialectRegistry ®istry,
101 llvm::ThreadPool *threadPool) {
102 // Tell sourceMgr about this buffer, which is what the parser will pick up.
103 SourceMgr sourceMgr;
104 sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
105
106 // Create a context just for the current buffer. Disable threading on creation
107 // since we'll inject the thread-pool separately.
108 MLIRContext context(registry, MLIRContext::Threading::DISABLED);
109 if (threadPool)
110 context.setThreadPool(*threadPool);
111
112 // Parse the input file.
113 if (preloadDialectsInContext)
114 context.loadAllAvailableDialects();
115 context.allowUnregisteredDialects(allowUnregisteredDialects);
116 if (verifyDiagnostics)
117 context.printOpOnDiagnostic(false);
118 context.getDebugActionManager().registerActionHandler<DebugCounter>();
119
120 // If we are in verify diagnostics mode then we have a lot of work to do,
121 // otherwise just perform the actions without worrying about it.
122 if (!verifyDiagnostics) {
123 SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
124 return performActions(os, verifyDiagnostics, verifyPasses, sourceMgr,
125 &context, passManagerSetupFn);
126 }
127
128 SourceMgrDiagnosticVerifierHandler sourceMgrHandler(sourceMgr, &context);
129
130 // Do any processing requested by command line flags. We don't care whether
131 // these actions succeed or fail, we only care what diagnostics they produce
132 // and whether they match our expectations.
133 (void)performActions(os, verifyDiagnostics, verifyPasses, sourceMgr, &context,
134 passManagerSetupFn);
135
136 // Verify the diagnostic handler to make sure that each of the diagnostics
137 // matched.
138 return sourceMgrHandler.verify();
139 }
140
MlirOptMain(raw_ostream & outputStream,std::unique_ptr<MemoryBuffer> buffer,PassPipelineFn passManagerSetupFn,DialectRegistry & registry,bool splitInputFile,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext)141 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
142 std::unique_ptr<MemoryBuffer> buffer,
143 PassPipelineFn passManagerSetupFn,
144 DialectRegistry ®istry, bool splitInputFile,
145 bool verifyDiagnostics, bool verifyPasses,
146 bool allowUnregisteredDialects,
147 bool preloadDialectsInContext) {
148 // The split-input-file mode is a very specific mode that slices the file
149 // up into small pieces and checks each independently.
150 // We use an explicit threadpool to avoid creating and joining/destroying
151 // threads for each of the split.
152 ThreadPool *threadPool = nullptr;
153
154 // Create a temporary context for the sake of checking if
155 // --mlir-disable-threading was passed on the command line.
156 // We use the thread-pool this context is creating, and avoid
157 // creating any thread when disabled.
158 MLIRContext threadPoolCtx;
159 if (threadPoolCtx.isMultithreadingEnabled())
160 threadPool = &threadPoolCtx.getThreadPool();
161
162 auto chunkFn = [&](std::unique_ptr<MemoryBuffer> chunkBuffer,
163 raw_ostream &os) {
164 return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
165 verifyPasses, allowUnregisteredDialects,
166 preloadDialectsInContext, passManagerSetupFn, registry,
167 threadPool);
168 };
169 return splitAndProcessBuffer(std::move(buffer), chunkFn, outputStream,
170 splitInputFile, /*insertMarkerInOutput=*/true);
171 }
172
MlirOptMain(raw_ostream & outputStream,std::unique_ptr<MemoryBuffer> buffer,const PassPipelineCLParser & passPipeline,DialectRegistry & registry,bool splitInputFile,bool verifyDiagnostics,bool verifyPasses,bool allowUnregisteredDialects,bool preloadDialectsInContext)173 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
174 std::unique_ptr<MemoryBuffer> buffer,
175 const PassPipelineCLParser &passPipeline,
176 DialectRegistry ®istry, bool splitInputFile,
177 bool verifyDiagnostics, bool verifyPasses,
178 bool allowUnregisteredDialects,
179 bool preloadDialectsInContext) {
180 auto passManagerSetupFn = [&](PassManager &pm) {
181 auto errorHandler = [&](const Twine &msg) {
182 emitError(UnknownLoc::get(pm.getContext())) << msg;
183 return failure();
184 };
185 return passPipeline.addToPipeline(pm, errorHandler);
186 };
187 return MlirOptMain(outputStream, std::move(buffer), passManagerSetupFn,
188 registry, splitInputFile, verifyDiagnostics, verifyPasses,
189 allowUnregisteredDialects, preloadDialectsInContext);
190 }
191
MlirOptMain(int argc,char ** argv,llvm::StringRef toolName,DialectRegistry & registry,bool preloadDialectsInContext)192 LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
193 DialectRegistry ®istry,
194 bool preloadDialectsInContext) {
195 static cl::opt<std::string> inputFilename(
196 cl::Positional, cl::desc("<input file>"), cl::init("-"));
197
198 static cl::opt<std::string> outputFilename("o", cl::desc("Output filename"),
199 cl::value_desc("filename"),
200 cl::init("-"));
201
202 static cl::opt<bool> splitInputFile(
203 "split-input-file",
204 cl::desc("Split the input file into pieces and process each "
205 "chunk independently"),
206 cl::init(false));
207
208 static cl::opt<bool> verifyDiagnostics(
209 "verify-diagnostics",
210 cl::desc("Check that emitted diagnostics match "
211 "expected-* lines on the corresponding line"),
212 cl::init(false));
213
214 static cl::opt<bool> verifyPasses(
215 "verify-each",
216 cl::desc("Run the verifier after each transformation pass"),
217 cl::init(true));
218
219 static cl::opt<bool> allowUnregisteredDialects(
220 "allow-unregistered-dialect",
221 cl::desc("Allow operation with no registered dialects"), cl::init(false));
222
223 static cl::opt<bool> showDialects(
224 "show-dialects", cl::desc("Print the list of registered dialects"),
225 cl::init(false));
226
227 InitLLVM y(argc, argv);
228
229 // Register any command line options.
230 registerAsmPrinterCLOptions();
231 registerMLIRContextCLOptions();
232 registerPassManagerCLOptions();
233 registerDefaultTimingManagerCLOptions();
234 DebugCounter::registerCLOptions();
235 PassPipelineCLParser passPipeline("", "Compiler passes to run");
236
237 // Build the list of dialects as a header for the --help message.
238 std::string helpHeader = (toolName + "\nAvailable Dialects: ").str();
239 {
240 llvm::raw_string_ostream os(helpHeader);
241 interleaveComma(registry.getDialectNames(), os,
242 [&](auto name) { os << name; });
243 }
244 // Parse pass names in main to ensure static initialization completed.
245 cl::ParseCommandLineOptions(argc, argv, helpHeader);
246
247 if (showDialects) {
248 llvm::outs() << "Available Dialects:\n";
249 interleave(
250 registry.getDialectNames(), llvm::outs(),
251 [](auto name) { llvm::outs() << name; }, "\n");
252 return success();
253 }
254
255 // Set up the input file.
256 std::string errorMessage;
257 auto file = openInputFile(inputFilename, &errorMessage);
258 if (!file) {
259 llvm::errs() << errorMessage << "\n";
260 return failure();
261 }
262
263 auto output = openOutputFile(outputFilename, &errorMessage);
264 if (!output) {
265 llvm::errs() << errorMessage << "\n";
266 return failure();
267 }
268
269 if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
270 splitInputFile, verifyDiagnostics, verifyPasses,
271 allowUnregisteredDialects, preloadDialectsInContext)))
272 return failure();
273
274 // Keep the output file if the invocation of MlirOptMain was successful.
275 output->keep();
276 return success();
277 }
278