1 //===- IRPrinting.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetail.h"
10 #include "mlir/IR/SymbolTable.h"
11 #include "mlir/Pass/PassManager.h"
12 #include "llvm/Support/Format.h"
13 #include "llvm/Support/FormatVariadic.h"
14 #include "llvm/Support/SHA1.h"
15 
16 using namespace mlir;
17 using namespace mlir::detail;
18 
19 namespace {
20 //===----------------------------------------------------------------------===//
21 // OperationFingerPrint
22 //===----------------------------------------------------------------------===//
23 
24 /// A unique fingerprint for a specific operation, and all of it's internal
25 /// operations.
26 class OperationFingerPrint {
27 public:
28   OperationFingerPrint(Operation *topOp) {
29     llvm::SHA1 hasher;
30 
31     // Hash each of the operations based upon their mutable bits:
32     topOp->walk([&](Operation *op) {
33       //   - Operation pointer
34       addDataToHash(hasher, op);
35       //   - Attributes
36       addDataToHash(hasher, op->getAttrDictionary());
37       //   - Blocks in Regions
38       for (Region &region : op->getRegions()) {
39         for (Block &block : region) {
40           addDataToHash(hasher, &block);
41           for (BlockArgument arg : block.getArguments())
42             addDataToHash(hasher, arg);
43         }
44       }
45       //   - Location
46       addDataToHash(hasher, op->getLoc().getAsOpaquePointer());
47       //   - Operands
48       for (Value operand : op->getOperands())
49         addDataToHash(hasher, operand);
50       //   - Successors
51       for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i)
52         addDataToHash(hasher, op->getSuccessor(i));
53     });
54     hash = hasher.result();
55   }
56 
57   bool operator==(const OperationFingerPrint &other) const {
58     return hash == other.hash;
59   }
60   bool operator!=(const OperationFingerPrint &other) const {
61     return !(*this == other);
62   }
63 
64 private:
65   template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) {
66     hasher.update(
67         ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T)));
68   }
69 
70   std::array<uint8_t, 20> hash;
71 };
72 
73 //===----------------------------------------------------------------------===//
74 // IRPrinter
75 //===----------------------------------------------------------------------===//
76 
77 class IRPrinterInstrumentation : public PassInstrumentation {
78 public:
79   IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config)
80       : config(std::move(config)) {}
81 
82 private:
83   /// Instrumentation hooks.
84   void runBeforePass(Pass *pass, Operation *op) override;
85   void runAfterPass(Pass *pass, Operation *op) override;
86   void runAfterPassFailed(Pass *pass, Operation *op) override;
87 
88   /// Configuration to use.
89   std::unique_ptr<PassManager::IRPrinterConfig> config;
90 
91   /// The following is a set of fingerprints for operations that are currently
92   /// being operated on in a pass. This field is only used when the
93   /// configuration asked for change detection.
94   DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints;
95 };
96 } // namespace
97 
98 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out,
99                     OpPrintingFlags flags) {
100   // Otherwise, check to see if we are not printing at module scope.
101   if (!printModuleScope)
102     return op->print(out << " //----- //\n",
103                      op->getBlock() ? flags.useLocalScope() : flags);
104 
105   // Otherwise, we are printing at module scope.
106   out << " ('" << op->getName() << "' operation";
107   if (auto symbolName =
108           op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()))
109     out << ": @" << symbolName.getValue();
110   out << ") //----- //\n";
111 
112   // Find the top-level operation.
113   auto *topLevelOp = op;
114   while (auto *parentOp = topLevelOp->getParentOp())
115     topLevelOp = parentOp;
116   topLevelOp->print(out, flags);
117 }
118 
119 /// Instrumentation hooks.
120 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) {
121   if (isa<OpToOpPassAdaptor>(pass))
122     return;
123   // If the config asked to detect changes, record the current fingerprint.
124   if (config->shouldPrintAfterOnlyOnChange())
125     beforePassFingerPrints.try_emplace(pass, op);
126 
127   config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) {
128     out << "// -----// IR Dump Before " << pass->getName();
129     printIR(op, config->shouldPrintAtModuleScope(), out,
130             config->getOpPrintingFlags());
131     out << "\n\n";
132   });
133 }
134 
135 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) {
136   if (isa<OpToOpPassAdaptor>(pass))
137     return;
138 
139   // Check to see if we are only printing on failure.
140   if (config->shouldPrintAfterOnlyOnFailure())
141     return;
142 
143   // If the config asked to detect changes, compare the current fingerprint with
144   // the previous.
145   if (config->shouldPrintAfterOnlyOnChange()) {
146     auto fingerPrintIt = beforePassFingerPrints.find(pass);
147     assert(fingerPrintIt != beforePassFingerPrints.end() &&
148            "expected valid fingerprint");
149     // If the fingerprints are the same, we don't print the IR.
150     if (fingerPrintIt->second == OperationFingerPrint(op)) {
151       beforePassFingerPrints.erase(fingerPrintIt);
152       return;
153     }
154     beforePassFingerPrints.erase(fingerPrintIt);
155   }
156 
157   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
158     out << "// -----// IR Dump After " << pass->getName();
159     printIR(op, config->shouldPrintAtModuleScope(), out,
160             config->getOpPrintingFlags());
161     out << "\n\n";
162   });
163 }
164 
165 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) {
166   if (isa<OpToOpPassAdaptor>(pass))
167     return;
168   if (config->shouldPrintAfterOnlyOnChange())
169     beforePassFingerPrints.erase(pass);
170 
171   config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) {
172     out << formatv("// -----// IR Dump After {0} Failed", pass->getName());
173     printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags());
174     out << "\n\n";
175   });
176 }
177 
178 //===----------------------------------------------------------------------===//
179 // IRPrinterConfig
180 //===----------------------------------------------------------------------===//
181 
/// Initialize the configuration. Each argument is stored verbatim in the
/// member of the same name:
///  * 'printModuleScope' - presumably the value reported by
///    shouldPrintAtModuleScope(); the accessor is declared elsewhere.
///  * 'printAfterOnlyOnChange' - presumably the value reported by
///    shouldPrintAfterOnlyOnChange(); the accessor is declared elsewhere.
///  * 'printAfterOnlyOnFailure' - presumably the value reported by
///    shouldPrintAfterOnlyOnFailure(); the accessor is declared elsewhere.
///  * 'opPrintingFlags' - the printing flags kept by the configuration.
PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope,
                                              bool printAfterOnlyOnChange,
                                              bool printAfterOnlyOnFailure,
                                              OpPrintingFlags opPrintingFlags)
    : printModuleScope(printModuleScope),
      printAfterOnlyOnChange(printAfterOnlyOnChange),
      printAfterOnlyOnFailure(printAfterOnlyOnFailure),
      opPrintingFlags(opPrintingFlags) {}
/// Defined out of line; nothing beyond the default member cleanup is done.
PassManager::IRPrinterConfig::~IRPrinterConfig() = default;
192 
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *before* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
///
/// This base implementation ignores all of its arguments and never invokes
/// 'printCallback'.
void PassManager::IRPrinterConfig::printBeforeIfEnabled(
    Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
  // By default, never print.
}
201 
/// A hook that may be overridden by a derived config that checks if the IR
/// of 'operation' should be dumped *after* the pass 'pass' has been
/// executed. If the IR should be dumped, 'printCallback' should be invoked
/// with the stream to dump into.
///
/// This base implementation ignores all of its arguments and never invokes
/// 'printCallback'.
void PassManager::IRPrinterConfig::printAfterIfEnabled(
    Pass *pass, Operation *operation, PrintCallbackFn printCallback) {
  // By default, never print.
}
210 
211 //===----------------------------------------------------------------------===//
212 // PassManager
213 //===----------------------------------------------------------------------===//
214 
215 namespace {
216 /// Simple wrapper config that allows for the simpler interface defined above.
217 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig {
218   BasicIRPrinterConfig(
219       std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
220       std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
221       bool printModuleScope, bool printAfterOnlyOnChange,
222       bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags,
223       raw_ostream &out)
224       : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange,
225                         printAfterOnlyOnFailure, opPrintingFlags),
226         shouldPrintBeforePass(std::move(shouldPrintBeforePass)),
227         shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) {
228     assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) &&
229            "expected at least one valid filter function");
230   }
231 
232   void printBeforeIfEnabled(Pass *pass, Operation *operation,
233                             PrintCallbackFn printCallback) final {
234     if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation))
235       printCallback(out);
236   }
237 
238   void printAfterIfEnabled(Pass *pass, Operation *operation,
239                            PrintCallbackFn printCallback) final {
240     if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation))
241       printCallback(out);
242   }
243 
244   /// Filter functions for before and after pass execution.
245   std::function<bool(Pass *, Operation *)> shouldPrintBeforePass;
246   std::function<bool(Pass *, Operation *)> shouldPrintAfterPass;
247 
248   /// The stream to output to.
249   raw_ostream &out;
250 };
251 } // namespace
252 
253 /// Add an instrumentation to print the IR before and after pass execution,
254 /// using the provided configuration.
255 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) {
256   if (config->shouldPrintAtModuleScope() &&
257       getContext()->isMultithreadingEnabled())
258     llvm::report_fatal_error("IR printing can't be setup on a pass-manager "
259                              "without disabling multi-threading first.");
260   addInstrumentation(
261       std::make_unique<IRPrinterInstrumentation>(std::move(config)));
262 }
263 
264 /// Add an instrumentation to print the IR before and after pass execution.
265 void PassManager::enableIRPrinting(
266     std::function<bool(Pass *, Operation *)> shouldPrintBeforePass,
267     std::function<bool(Pass *, Operation *)> shouldPrintAfterPass,
268     bool printModuleScope, bool printAfterOnlyOnChange,
269     bool printAfterOnlyOnFailure, raw_ostream &out,
270     OpPrintingFlags opPrintingFlags) {
271   enableIRPrinting(std::make_unique<BasicIRPrinterConfig>(
272       std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass),
273       printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure,
274       opPrintingFlags, out));
275 }
276