1 //===- IRPrinting.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "PassDetail.h"
10 #include "mlir/IR/SymbolTable.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/Support/Format.h"
13 #include "llvm/Support/FormatVariadic.h"
14 #include "llvm/Support/SHA1.h"
15
16 using namespace mlir;
17 using namespace mlir::detail;
18
19 namespace {
20 //===----------------------------------------------------------------------===//
21 // OperationFingerPrint
22 //===----------------------------------------------------------------------===//
23
24 /// A unique fingerprint for a specific operation, and all of it's internal
25 /// operations.
26 class OperationFingerPrint {
27 public:
OperationFingerPrint(Operation * topOp)28 OperationFingerPrint(Operation *topOp) {
29 llvm::SHA1 hasher;
30
31 // Hash each of the operations based upon their mutable bits:
32 topOp->walk([&](Operation *op) {
33 // - Operation pointer
34 addDataToHash(hasher, op);
35 // - Attributes
36 addDataToHash(hasher, op->getAttrDictionary());
37 // - Blocks in Regions
38 for (Region ®ion : op->getRegions()) {
39 for (Block &block : region) {
40 addDataToHash(hasher, &block);
41 for (BlockArgument arg : block.getArguments())
42 addDataToHash(hasher, arg);
43 }
44 }
45 // - Location
46 addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
47 // - Operands
48 for (Value operand : op->getOperands())
49 addDataToHash(hasher, operand);
50 // - Successors
51 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
52 addDataToHash(hasher, op->getSuccessor(i));
53 });
54 hash = hasher.result();
55 }
56
operator ==(const OperationFingerPrint & other) const57 bool operator==(const OperationFingerPrint &other) const {
58 return hash == other.hash;
59 }
operator !=(const OperationFingerPrint & other) const60 bool operator!=(const OperationFingerPrint &other) const {
61 return !(*this == other);
62 }
63
64 private:
65 template <typename T>
addDataToHash(llvm::SHA1 & hasher,const T & data)66 void addDataToHash(llvm::SHA1 &hasher, const T &data) {
67 hasher.update(
68 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
69 }
70
71 std::array<uint8_t, 20> hash;
72 };
73
74 //===----------------------------------------------------------------------===//
75 // IRPrinter
76 //===----------------------------------------------------------------------===//
77
78 class IRPrinterInstrumentation : public PassInstrumentation {
79 public:
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)80 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
81 : config(std::move(config)) {}
82
83 private:
84 /// Instrumentation hooks.
85 void runBeforePass(Pass *pass, Operation *op) override;
86 void runAfterPass(Pass *pass, Operation *op) override;
87 void runAfterPassFailed(Pass *pass, Operation *op) override;
88
89 /// Configuration to use.
90 std::unique_ptr<PassManager::IRPrinterConfig> config;
91
92 /// The following is a set of fingerprints for operations that are currently
93 /// being operated on in a pass. This field is only used when the
94 /// configuration asked for change detection.
95 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
96 };
97 } // namespace
98
printIR(Operation * op,bool printModuleScope,raw_ostream & out,OpPrintingFlags flags)99 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
100 OpPrintingFlags flags) {
101 // Otherwise, check to see if we are not printing at module scope.
102 if (!printModuleScope)
103 return op->print(out << " //----- //\n",
104 op->getBlock() ? flags.useLocalScope() : flags);
105
106 // Otherwise, we are printing at module scope.
107 out << " ('" << op->getName() << "' operation";
108 if (auto symbolName =
109 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
110 out << ": @" << symbolName.getValue();
111 out << ") //----- //\n";
112
113 // Find the top-level operation.
114 auto *topLevelOp = op;
115 while (auto *parentOp = topLevelOp->getParentOp())
116 topLevelOp = parentOp;
117 topLevelOp->print(out, flags);
118 }
119
120 /// Instrumentation hooks.
runBeforePass(Pass * pass,Operation * op)121 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
122 if (isa<OpToOpPassAdaptor>(pass))
123 return;
124 // If the config asked to detect changes, record the current fingerprint.
125 if (config->shouldPrintAfterOnlyOnChange())
126 beforePassFingerPrints.try_emplace(pass, op);
127
128 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
129 out << "// -----// IR Dump Before " << pass->getName() << " ("
130 << pass->getArgument() << ")";
131 printIR(op, config->shouldPrintAtModuleScope(), out,
132 config->getOpPrintingFlags());
133 out << "\n\n";
134 });
135 }
136
runAfterPass(Pass * pass,Operation * op)137 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
138 if (isa<OpToOpPassAdaptor>(pass))
139 return;
140
141 // Check to see if we are only printing on failure.
142 if (config->shouldPrintAfterOnlyOnFailure())
143 return;
144
145 // If the config asked to detect changes, compare the current fingerprint with
146 // the previous.
147 if (config->shouldPrintAfterOnlyOnChange()) {
148 auto fingerPrintIt = beforePassFingerPrints.find(pass);
149 assert(fingerPrintIt != beforePassFingerPrints.end() &&
150 "expected valid fingerprint");
151 // If the fingerprints are the same, we don't print the IR.
152 if (fingerPrintIt->second == OperationFingerPrint(op)) {
153 beforePassFingerPrints.erase(fingerPrintIt);
154 return;
155 }
156 beforePassFingerPrints.erase(fingerPrintIt);
157 }
158
159 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
160 out << "// -----// IR Dump After " << pass->getName() << " ("
161 << pass->getArgument() << ")";
162 printIR(op, config->shouldPrintAtModuleScope(), out,
163 config->getOpPrintingFlags());
164 out << "\n\n";
165 });
166 }
167
runAfterPassFailed(Pass * pass,Operation * op)168 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
169 if (isa<OpToOpPassAdaptor>(pass))
170 return;
171 if (config->shouldPrintAfterOnlyOnChange())
172 beforePassFingerPrints.erase(pass);
173
174 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
175 out << formatv("// -----// IR Dump After {0} Failed ({1})", pass->getName(),
176 pass->getArgument());
177 printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
178 out << "\n\n";
179 });
180 }
181
182 //===----------------------------------------------------------------------===//
183 // IRPrinterConfig
184 //===----------------------------------------------------------------------===//
185
186 /// Initialize the configuration.
IRPrinterConfig(bool printModuleScope,bool printAfterOnlyOnChange,bool printAfterOnlyOnFailure,OpPrintingFlags opPrintingFlags)187 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
188 bool printAfterOnlyOnChange,
189 bool printAfterOnlyOnFailure,
190 OpPrintingFlags opPrintingFlags)
191 : printModuleScope(printModuleScope),
192 printAfterOnlyOnChange(printAfterOnlyOnChange),
193 printAfterOnlyOnFailure(printAfterOnlyOnFailure),
194 opPrintingFlags(opPrintingFlags) {}
195 PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
196
197 /// A hook that may be overridden by a derived config that checks if the IR
198 /// of 'operation' should be dumped *before* the pass 'pass' has been
199 /// executed. If the IR should be dumped, 'printCallback' should be invoked
200 /// with the stream to dump into.
printBeforeIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)201 void PassManager::IRPrinterConfig::printBeforeIfEnabled(
202 Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
203 // By default, never print.
204 }
205
206 /// A hook that may be overridden by a derived config that checks if the IR
207 /// of 'operation' should be dumped *after* the pass 'pass' has been
208 /// executed. If the IR should be dumped, 'printCallback' should be invoked
209 /// with the stream to dump into.
printAfterIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)210 void PassManager::IRPrinterConfig::printAfterIfEnabled(
211 Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
212 // By default, never print.
213 }
214
215 //===----------------------------------------------------------------------===//
216 // PassManager
217 //===----------------------------------------------------------------------===//
218
219 namespace {
220 /// Simple wrapper config that allows for the simpler interface defined above.
221 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig__anonc32b97650611::BasicIRPrinterConfig222 BasicIRPrinterConfig(
223 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
224 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
225 bool printModuleScope, bool printAfterOnlyOnChange,
226 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
227 raw_ostream &out)
228 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
229 printAfterOnlyOnFailure, opPrintingFlags),
230 shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
231 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
232 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
233 "expected at least one valid filter function");
234 }
235
printBeforeIfEnabled__anonc32b97650611::BasicIRPrinterConfig236 void printBeforeIfEnabled(Pass *pass, Operation *operation,
237 PrintCallbackFn printCallback) final {
238 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
239 printCallback(out);
240 }
241
printAfterIfEnabled__anonc32b97650611::BasicIRPrinterConfig242 void printAfterIfEnabled(Pass *pass, Operation *operation,
243 PrintCallbackFn printCallback) final {
244 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
245 printCallback(out);
246 }
247
248 /// Filter functions for before and after pass execution.
249 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
250 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
251
252 /// The stream to output to.
253 raw_ostream &out;
254 };
255 } // namespace
256
257 /// Add an instrumentation to print the IR before and after pass execution,
258 /// using the provided configuration.
enableIRPrinting(std::unique_ptr<IRPrinterConfig> config)259 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
260 if (config->shouldPrintAtModuleScope() &&
261 getContext()->isMultithreadingEnabled())
262 llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
263 "without disabling multi-threading first.");
264 addInstrumentation(
265 std::make_unique<IRPrinterInstrumentation>(std::move(config)));
266 }
267
268 /// Add an instrumentation to print the IR before and after pass execution.
enableIRPrinting(std::function<bool (Pass *,Operation *)> shouldPrintBeforePass,std::function<bool (Pass *,Operation *)> shouldPrintAfterPass,bool printModuleScope,bool printAfterOnlyOnChange,bool printAfterOnlyOnFailure,raw_ostream & out,OpPrintingFlags opPrintingFlags)269 void PassManager::enableIRPrinting(
270 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
271 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
272 bool printModuleScope, bool printAfterOnlyOnChange,
273 bool printAfterOnlyOnFailure, raw_ostream &out,
274 OpPrintingFlags opPrintingFlags) {
275 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
276 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
277 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
278 opPrintingFlags, out));
279 }
280