1 //===- IRPrinting.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/SymbolTable.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "llvm/Support/Format.h" 13 #include "llvm/Support/FormatVariadic.h" 14 #include "llvm/Support/SHA1.h" 15 16 using namespace mlir; 17 using namespace mlir::detail; 18 19 namespace { 20 //===----------------------------------------------------------------------===// 21 // OperationFingerPrint 22 //===----------------------------------------------------------------------===// 23 24 /// A unique fingerprint for a specific operation, and all of it's internal 25 /// operations. 26 class OperationFingerPrint { 27 public: 28 OperationFingerPrint(Operation *topOp) { 29 llvm::SHA1 hasher; 30 31 // Hash each of the operations based upon their mutable bits: 32 topOp->walk([&](Operation *op) { 33 // - Operation pointer 34 addDataToHash(hasher, op); 35 // - Attributes 36 addDataToHash(hasher, op->getAttrDictionary()); 37 // - Blocks in Regions 38 for (Region ®ion : op->getRegions()) { 39 for (Block &block : region) { 40 addDataToHash(hasher, &block); 41 for (BlockArgument arg : block.getArguments()) 42 addDataToHash(hasher, arg); 43 } 44 } 45 // - Location 46 addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); 47 // - Operands 48 for (Value operand : op->getOperands()) 49 addDataToHash(hasher, operand); 50 // - Successors 51 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) 52 addDataToHash(hasher, op->getSuccessor(i)); 53 }); 54 hash = hasher.result(); 55 } 56 57 bool operator==(const OperationFingerPrint &other) const { 58 return hash == other.hash; 59 } 60 bool operator!=(const OperationFingerPrint &other) const { 61 return !(*this == other); 62 } 63 64 private: 65 template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) { 66 hasher.update( 67 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T))); 68 } 69 70 std::array<uint8_t, 20> hash; 71 }; 72 73 //===----------------------------------------------------------------------===// 74 // IRPrinter 75 //===----------------------------------------------------------------------===// 76 77 class IRPrinterInstrumentation : public PassInstrumentation { 78 public: 79 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) 80 : config(std::move(config)) {} 81 82 private: 83 /// Instrumentation hooks. 84 void runBeforePass(Pass *pass, Operation *op) override; 85 void runAfterPass(Pass *pass, Operation *op) override; 86 void runAfterPassFailed(Pass *pass, Operation *op) override; 87 88 /// Configuration to use. 89 std::unique_ptr<PassManager::IRPrinterConfig> config; 90 91 /// The following is a set of fingerprints for operations that are currently 92 /// being operated on in a pass. This field is only used when the 93 /// configuration asked for change detection. 94 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; 95 }; 96 } // namespace 97 98 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, 99 OpPrintingFlags flags) { 100 // Otherwise, check to see if we are not printing at module scope. 101 if (!printModuleScope) 102 return op->print(out << " //----- //\n", 103 op->getBlock() ? flags.useLocalScope() : flags); 104 105 // Otherwise, we are printing at module scope. 106 out << " ('" << op->getName() << "' operation"; 107 if (auto symbolName = 108 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) 109 out << ": @" << symbolName.getValue(); 110 out << ") //----- //\n"; 111 112 // Find the top-level operation. 113 auto *topLevelOp = op; 114 while (auto *parentOp = topLevelOp->getParentOp()) 115 topLevelOp = parentOp; 116 topLevelOp->print(out, flags); 117 } 118 119 /// Instrumentation hooks. 120 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { 121 if (isa<OpToOpPassAdaptor>(pass)) 122 return; 123 // If the config asked to detect changes, record the current fingerprint. 124 if (config->shouldPrintAfterOnlyOnChange()) 125 beforePassFingerPrints.try_emplace(pass, op); 126 127 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { 128 out << "// -----// IR Dump Before " << pass->getName(); 129 printIR(op, config->shouldPrintAtModuleScope(), out, 130 config->getOpPrintingFlags()); 131 out << "\n\n"; 132 }); 133 } 134 135 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { 136 if (isa<OpToOpPassAdaptor>(pass)) 137 return; 138 139 // Check to see if we are only printing on failure. 140 if (config->shouldPrintAfterOnlyOnFailure()) 141 return; 142 143 // If the config asked to detect changes, compare the current fingerprint with 144 // the previous. 145 if (config->shouldPrintAfterOnlyOnChange()) { 146 auto fingerPrintIt = beforePassFingerPrints.find(pass); 147 assert(fingerPrintIt != beforePassFingerPrints.end() && 148 "expected valid fingerprint"); 149 // If the fingerprints are the same, we don't print the IR. 150 if (fingerPrintIt->second == OperationFingerPrint(op)) { 151 beforePassFingerPrints.erase(fingerPrintIt); 152 return; 153 } 154 beforePassFingerPrints.erase(fingerPrintIt); 155 } 156 157 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 158 out << "// -----// IR Dump After " << pass->getName(); 159 printIR(op, config->shouldPrintAtModuleScope(), out, 160 config->getOpPrintingFlags()); 161 out << "\n\n"; 162 }); 163 } 164 165 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { 166 if (isa<OpToOpPassAdaptor>(pass)) 167 return; 168 if (config->shouldPrintAfterOnlyOnChange()) 169 beforePassFingerPrints.erase(pass); 170 171 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 172 out << formatv("// -----// IR Dump After {0} Failed", pass->getName()); 173 printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); 174 out << "\n\n"; 175 }); 176 } 177 178 //===----------------------------------------------------------------------===// 179 // IRPrinterConfig 180 //===----------------------------------------------------------------------===// 181 182 /// Initialize the configuration. 183 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, 184 bool printAfterOnlyOnChange, 185 bool printAfterOnlyOnFailure, 186 OpPrintingFlags opPrintingFlags) 187 : printModuleScope(printModuleScope), 188 printAfterOnlyOnChange(printAfterOnlyOnChange), 189 printAfterOnlyOnFailure(printAfterOnlyOnFailure), 190 opPrintingFlags(opPrintingFlags) {} 191 PassManager::IRPrinterConfig::~IRPrinterConfig() = default; 192 193 /// A hook that may be overridden by a derived config that checks if the IR 194 /// of 'operation' should be dumped *before* the pass 'pass' has been 195 /// executed. If the IR should be dumped, 'printCallback' should be invoked 196 /// with the stream to dump into. 197 void PassManager::IRPrinterConfig::printBeforeIfEnabled( 198 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 199 // By default, never print. 200 } 201 202 /// A hook that may be overridden by a derived config that checks if the IR 203 /// of 'operation' should be dumped *after* the pass 'pass' has been 204 /// executed. If the IR should be dumped, 'printCallback' should be invoked 205 /// with the stream to dump into. 206 void PassManager::IRPrinterConfig::printAfterIfEnabled( 207 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 208 // By default, never print. 209 } 210 211 //===----------------------------------------------------------------------===// 212 // PassManager 213 //===----------------------------------------------------------------------===// 214 215 namespace { 216 /// Simple wrapper config that allows for the simpler interface defined above. 217 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { 218 BasicIRPrinterConfig( 219 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 220 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 221 bool printModuleScope, bool printAfterOnlyOnChange, 222 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, 223 raw_ostream &out) 224 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 225 printAfterOnlyOnFailure, opPrintingFlags), 226 shouldPrintBeforePass(std::move(shouldPrintBeforePass)), 227 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { 228 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && 229 "expected at least one valid filter function"); 230 } 231 232 void printBeforeIfEnabled(Pass *pass, Operation *operation, 233 PrintCallbackFn printCallback) final { 234 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) 235 printCallback(out); 236 } 237 238 void printAfterIfEnabled(Pass *pass, Operation *operation, 239 PrintCallbackFn printCallback) final { 240 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) 241 printCallback(out); 242 } 243 244 /// Filter functions for before and after pass execution. 245 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 246 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 247 248 /// The stream to output to. 249 raw_ostream &out; 250 }; 251 } // namespace 252 253 /// Add an instrumentation to print the IR before and after pass execution, 254 /// using the provided configuration. 255 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { 256 if (config->shouldPrintAtModuleScope() && 257 getContext()->isMultithreadingEnabled()) 258 llvm::report_fatal_error("IR printing can't be setup on a pass-manager " 259 "without disabling multi-threading first."); 260 addInstrumentation( 261 std::make_unique<IRPrinterInstrumentation>(std::move(config))); 262 } 263 264 /// Add an instrumentation to print the IR before and after pass execution. 265 void PassManager::enableIRPrinting( 266 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 267 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 268 bool printModuleScope, bool printAfterOnlyOnChange, 269 bool printAfterOnlyOnFailure, raw_ostream &out, 270 OpPrintingFlags opPrintingFlags) { 271 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( 272 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 273 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, 274 opPrintingFlags, out)); 275 } 276