1 //===- IRPrinting.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/SymbolTable.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/Support/Format.h"
13 #include "llvm/Support/FormatVariadic.h"
14 #include "llvm/Support/SHA1.h"
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
19 namespace {
20 //===----------------------------------------------------------------------===//
21 // OperationFingerPrint
22 //===----------------------------------------------------------------------===//
23 
24 /// A unique fingerprint for a specific operation, and all of it's internal
25 /// operations.
26 class OperationFingerPrint {
27 public:
OperationFingerPrint(Operation * topOp)28   OperationFingerPrint(Operation *topOp) {
29     llvm::SHA1 hasher;
30 
31     // Hash each of the operations based upon their mutable bits:
32     topOp->walk([&](Operation *op) {
33       //   - Operation pointer
34       addDataToHash(hasher, op);
35       //   - Attributes
36       addDataToHash(hasher, op->getAttrDictionary());
37       //   - Blocks in Regions
38       for (Region &region : op->getRegions()) {
39         for (Block &block : region) {
40           addDataToHash(hasher, &block);
41           for (BlockArgument arg : block.getArguments())
42             addDataToHash(hasher, arg);
43         }
44       }
45       //   - Location
46       addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
47       //   - Operands
48       for (Value operand : op->getOperands())
49         addDataToHash(hasher, operand);
50       //   - Successors
51       for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
52         addDataToHash(hasher, op->getSuccessor(i));
53     });
54     hash = hasher.result();
55   }
56 
operator ==(const OperationFingerPrint & other) const57   bool operator==(const OperationFingerPrint &other) const {
58     return hash == other.hash;
59   }
operator !=(const OperationFingerPrint & other) const60   bool operator!=(const OperationFingerPrint &other) const {
61     return !(*this == other);
62   }
63 
64 private:
65   template <typename T>
addDataToHash(llvm::SHA1 & hasher,const T & data)66   void addDataToHash(llvm::SHA1 &hasher, const T &data) {
67     hasher.update(
68         ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
69   }
70 
71   std::array<uint8_t, 20> hash;
72 };
73 
74 //===----------------------------------------------------------------------===//
75 // IRPrinter
76 //===----------------------------------------------------------------------===//
77 
78 class IRPrinterInstrumentation : public PassInstrumentation {
79 public:
IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)80   IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
81       : config(std::move(config)) {}
82 
83 private:
84   /// Instrumentation hooks.
85   void runBeforePass(Pass *pass, Operation *op) override;
86   void runAfterPass(Pass *pass, Operation *op) override;
87   void runAfterPassFailed(Pass *pass, Operation *op) override;
88 
89   /// Configuration to use.
90   std::unique_ptr<PassManager::IRPrinterConfig> config;
91 
92   /// The following is a set of fingerprints for operations that are currently
93   /// being operated on in a pass. This field is only used when the
94   /// configuration asked for change detection.
95   DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
96 };
97 } // namespace
98 
printIR(Operation * op,bool printModuleScope,raw_ostream & out,OpPrintingFlags flags)99 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
100                     OpPrintingFlags flags) {
101   // Otherwise, check to see if we are not printing at module scope.
102   if (!printModuleScope)
103     return op->print(out << " //----- //\n",
104                      op->getBlock() ? flags.useLocalScope() : flags);
105 
106   // Otherwise, we are printing at module scope.
107   out << " ('" << op->getName() << "' operation";
108   if (auto symbolName =
109           op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
110     out << ": @" << symbolName.getValue();
111   out << ") //----- //\n";
112 
113   // Find the top-level operation.
114   auto *topLevelOp = op;
115   while (auto *parentOp = topLevelOp->getParentOp())
116     topLevelOp = parentOp;
117   topLevelOp->print(out, flags);
118 }
119 
120 /// Instrumentation hooks.
runBeforePass(Pass * pass,Operation * op)121 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
122   if (isa<OpToOpPassAdaptor>(pass))
123     return;
124   // If the config asked to detect changes, record the current fingerprint.
125   if (config->shouldPrintAfterOnlyOnChange())
126     beforePassFingerPrints.try_emplace(pass, op);
127 
128   config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
129     out << "// -----// IR Dump Before " << pass->getName() << " ("
130         << pass->getArgument() << ")";
131     printIR(op, config->shouldPrintAtModuleScope(), out,
132             config->getOpPrintingFlags());
133     out << "\n\n";
134   });
135 }
136 
runAfterPass(Pass * pass,Operation * op)137 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
138   if (isa<OpToOpPassAdaptor>(pass))
139     return;
140 
141   // Check to see if we are only printing on failure.
142   if (config->shouldPrintAfterOnlyOnFailure())
143     return;
144 
145   // If the config asked to detect changes, compare the current fingerprint with
146   // the previous.
147   if (config->shouldPrintAfterOnlyOnChange()) {
148     auto fingerPrintIt = beforePassFingerPrints.find(pass);
149     assert(fingerPrintIt != beforePassFingerPrints.end() &&
150            "expected valid fingerprint");
151     // If the fingerprints are the same, we don't print the IR.
152     if (fingerPrintIt->second == OperationFingerPrint(op)) {
153       beforePassFingerPrints.erase(fingerPrintIt);
154       return;
155     }
156     beforePassFingerPrints.erase(fingerPrintIt);
157   }
158 
159   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
160     out << "// -----// IR Dump After " << pass->getName() << " ("
161         << pass->getArgument() << ")";
162     printIR(op, config->shouldPrintAtModuleScope(), out,
163             config->getOpPrintingFlags());
164     out << "\n\n";
165   });
166 }
167 
runAfterPassFailed(Pass * pass,Operation * op)168 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
169   if (isa<OpToOpPassAdaptor>(pass))
170     return;
171   if (config->shouldPrintAfterOnlyOnChange())
172     beforePassFingerPrints.erase(pass);
173 
174   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
175     out << formatv("// -----// IR Dump After {0} Failed ({1})", pass->getName(),
176                    pass->getArgument());
177     printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
178     out << "\n\n";
179   });
180 }
181 
182 //===----------------------------------------------------------------------===//
183 // IRPrinterConfig
184 //===----------------------------------------------------------------------===//
185 
186 /// Initialize the configuration.
IRPrinterConfig(bool printModuleScope,bool printAfterOnlyOnChange,bool printAfterOnlyOnFailure,OpPrintingFlags opPrintingFlags)187 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
188                                               bool printAfterOnlyOnChange,
189                                               bool printAfterOnlyOnFailure,
190                                               OpPrintingFlags opPrintingFlags)
191     : printModuleScope(printModuleScope),
192       printAfterOnlyOnChange(printAfterOnlyOnChange),
193       printAfterOnlyOnFailure(printAfterOnlyOnFailure),
194       opPrintingFlags(opPrintingFlags) {}
195 PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
196 
197 /// A hook that may be overridden by a derived config that checks if the IR
198 /// of 'operation' should be dumped *before* the pass 'pass' has been
199 /// executed. If the IR should be dumped, 'printCallback' should be invoked
200 /// with the stream to dump into.
printBeforeIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)201 void PassManager::IRPrinterConfig::printBeforeIfEnabled(
202     Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
203   // By default, never print.
204 }
205 
206 /// A hook that may be overridden by a derived config that checks if the IR
207 /// of 'operation' should be dumped *after* the pass 'pass' has been
208 /// executed. If the IR should be dumped, 'printCallback' should be invoked
209 /// with the stream to dump into.
printAfterIfEnabled(Pass * pass,Operation * operation,PrintCallbackFn printCallback)210 void PassManager::IRPrinterConfig::printAfterIfEnabled(
211     Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
212   // By default, never print.
213 }
214 
215 //===----------------------------------------------------------------------===//
216 // PassManager
217 //===----------------------------------------------------------------------===//
218 
219 namespace {
220 /// Simple wrapper config that allows for the simpler interface defined above.
221 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
BasicIRPrinterConfig__anonc32b97650611::BasicIRPrinterConfig222   BasicIRPrinterConfig(
223       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
224       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
225       bool printModuleScope, bool printAfterOnlyOnChange,
226       bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
227       raw_ostream &out)
228       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
229                         printAfterOnlyOnFailure, opPrintingFlags),
230         shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
231         shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
232     assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
233            "expected at least one valid filter function");
234   }
235 
printBeforeIfEnabled__anonc32b97650611::BasicIRPrinterConfig236   void printBeforeIfEnabled(Pass *pass, Operation *operation,
237                             PrintCallbackFn printCallback) final {
238     if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
239       printCallback(out);
240   }
241 
printAfterIfEnabled__anonc32b97650611::BasicIRPrinterConfig242   void printAfterIfEnabled(Pass *pass, Operation *operation,
243                            PrintCallbackFn printCallback) final {
244     if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
245       printCallback(out);
246   }
247 
248   /// Filter functions for before and after pass execution.
249   std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
250   std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
251 
252   /// The stream to output to.
253   raw_ostream &out;
254 };
255 } // namespace
256 
257 /// Add an instrumentation to print the IR before and after pass execution,
258 /// using the provided configuration.
enableIRPrinting(std::unique_ptr<IRPrinterConfig> config)259 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
260   if (config->shouldPrintAtModuleScope() &&
261       getContext()->isMultithreadingEnabled())
262     llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
263                              "without disabling multi-threading first.");
264   addInstrumentation(
265       std::make_unique<IRPrinterInstrumentation>(std::move(config)));
266 }
267 
268 /// Add an instrumentation to print the IR before and after pass execution.
enableIRPrinting(std::function<bool (Pass *,Operation *)> shouldPrintBeforePass,std::function<bool (Pass *,Operation *)> shouldPrintAfterPass,bool printModuleScope,bool printAfterOnlyOnChange,bool printAfterOnlyOnFailure,raw_ostream & out,OpPrintingFlags opPrintingFlags)269 void PassManager::enableIRPrinting(
270     std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
271     std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
272     bool printModuleScope, bool printAfterOnlyOnChange,
273     bool printAfterOnlyOnFailure, raw_ostream &out,
274     OpPrintingFlags opPrintingFlags) {
275   enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
276       std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
277       printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
278       opPrintingFlags, out));
279 }
280