1 //===- IRPrinting.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/SymbolTable.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "llvm/Support/Format.h" 13 #include "llvm/Support/FormatVariadic.h" 14 #include "llvm/Support/SHA1.h" 15 16 using namespace mlir; 17 using namespace mlir::detail; 18 19 namespace { 20 //===----------------------------------------------------------------------===// 21 // OperationFingerPrint 22 //===----------------------------------------------------------------------===// 23 24 /// A unique fingerprint for a specific operation, and all of it's internal 25 /// operations. 26 class OperationFingerPrint { 27 public: 28 OperationFingerPrint(Operation *topOp) { 29 llvm::SHA1 hasher; 30 31 // Hash each of the operations based upon their mutable bits: 32 topOp->walk([&](Operation *op) { 33 // - Operation pointer 34 addDataToHash(hasher, op); 35 // - Attributes 36 addDataToHash(hasher, op->getAttrDictionary()); 37 // - Blocks in Regions 38 for (Region ®ion : op->getRegions()) { 39 for (Block &block : region) { 40 addDataToHash(hasher, &block); 41 for (BlockArgument arg : block.getArguments()) 42 addDataToHash(hasher, arg); 43 } 44 } 45 // - Location 46 addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); 47 // - Operands 48 for (Value operand : op->getOperands()) 49 addDataToHash(hasher, operand); 50 // - Successors 51 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) 52 addDataToHash(hasher, op->getSuccessor(i)); 53 }); 54 hash = hasher.result(); 55 } 56 57 bool operator==(const OperationFingerPrint &other) const { 58 return hash == other.hash; 59 } 60 bool operator!=(const OperationFingerPrint &other) const { 61 return !(*this == other); 62 } 63 64 private: 65 template <typename T> 66 void addDataToHash(llvm::SHA1 &hasher, const T &data) { 67 hasher.update( 68 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T))); 69 } 70 71 std::array<uint8_t, 20> hash; 72 }; 73 74 //===----------------------------------------------------------------------===// 75 // IRPrinter 76 //===----------------------------------------------------------------------===// 77 78 class IRPrinterInstrumentation : public PassInstrumentation { 79 public: 80 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) 81 : config(std::move(config)) {} 82 83 private: 84 /// Instrumentation hooks. 85 void runBeforePass(Pass *pass, Operation *op) override; 86 void runAfterPass(Pass *pass, Operation *op) override; 87 void runAfterPassFailed(Pass *pass, Operation *op) override; 88 89 /// Configuration to use. 90 std::unique_ptr<PassManager::IRPrinterConfig> config; 91 92 /// The following is a set of fingerprints for operations that are currently 93 /// being operated on in a pass. This field is only used when the 94 /// configuration asked for change detection. 95 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; 96 }; 97 } // namespace 98 99 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, 100 OpPrintingFlags flags) { 101 // Otherwise, check to see if we are not printing at module scope. 102 if (!printModuleScope) 103 return op->print(out << " //----- //\n", 104 op->getBlock() ? flags.useLocalScope() : flags); 105 106 // Otherwise, we are printing at module scope. 107 out << " ('" << op->getName() << "' operation"; 108 if (auto symbolName = 109 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) 110 out << ": @" << symbolName.getValue(); 111 out << ") //----- //\n"; 112 113 // Find the top-level operation. 114 auto *topLevelOp = op; 115 while (auto *parentOp = topLevelOp->getParentOp()) 116 topLevelOp = parentOp; 117 topLevelOp->print(out, flags); 118 } 119 120 /// Instrumentation hooks. 121 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { 122 if (isa<OpToOpPassAdaptor>(pass)) 123 return; 124 // If the config asked to detect changes, record the current fingerprint. 125 if (config->shouldPrintAfterOnlyOnChange()) 126 beforePassFingerPrints.try_emplace(pass, op); 127 128 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { 129 out << "// -----// IR Dump Before " << pass->getName() << " (" 130 << pass->getArgument() << ")"; 131 printIR(op, config->shouldPrintAtModuleScope(), out, 132 config->getOpPrintingFlags()); 133 out << "\n\n"; 134 }); 135 } 136 137 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { 138 if (isa<OpToOpPassAdaptor>(pass)) 139 return; 140 141 // Check to see if we are only printing on failure. 142 if (config->shouldPrintAfterOnlyOnFailure()) 143 return; 144 145 // If the config asked to detect changes, compare the current fingerprint with 146 // the previous. 147 if (config->shouldPrintAfterOnlyOnChange()) { 148 auto fingerPrintIt = beforePassFingerPrints.find(pass); 149 assert(fingerPrintIt != beforePassFingerPrints.end() && 150 "expected valid fingerprint"); 151 // If the fingerprints are the same, we don't print the IR. 152 if (fingerPrintIt->second == OperationFingerPrint(op)) { 153 beforePassFingerPrints.erase(fingerPrintIt); 154 return; 155 } 156 beforePassFingerPrints.erase(fingerPrintIt); 157 } 158 159 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 160 out << "// -----// IR Dump After " << pass->getName() << " (" 161 << pass->getArgument() << ")"; 162 printIR(op, config->shouldPrintAtModuleScope(), out, 163 config->getOpPrintingFlags()); 164 out << "\n\n"; 165 }); 166 } 167 168 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { 169 if (isa<OpToOpPassAdaptor>(pass)) 170 return; 171 if (config->shouldPrintAfterOnlyOnChange()) 172 beforePassFingerPrints.erase(pass); 173 174 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 175 out << formatv("// -----// IR Dump After {0} Failed ({1})", pass->getName(), 176 pass->getArgument()); 177 printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); 178 out << "\n\n"; 179 }); 180 } 181 182 //===----------------------------------------------------------------------===// 183 // IRPrinterConfig 184 //===----------------------------------------------------------------------===// 185 186 /// Initialize the configuration. 187 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, 188 bool printAfterOnlyOnChange, 189 bool printAfterOnlyOnFailure, 190 OpPrintingFlags opPrintingFlags) 191 : printModuleScope(printModuleScope), 192 printAfterOnlyOnChange(printAfterOnlyOnChange), 193 printAfterOnlyOnFailure(printAfterOnlyOnFailure), 194 opPrintingFlags(opPrintingFlags) {} 195 PassManager::IRPrinterConfig::~IRPrinterConfig() = default; 196 197 /// A hook that may be overridden by a derived config that checks if the IR 198 /// of 'operation' should be dumped *before* the pass 'pass' has been 199 /// executed. If the IR should be dumped, 'printCallback' should be invoked 200 /// with the stream to dump into. 201 void PassManager::IRPrinterConfig::printBeforeIfEnabled( 202 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 203 // By default, never print. 204 } 205 206 /// A hook that may be overridden by a derived config that checks if the IR 207 /// of 'operation' should be dumped *after* the pass 'pass' has been 208 /// executed. If the IR should be dumped, 'printCallback' should be invoked 209 /// with the stream to dump into. 210 void PassManager::IRPrinterConfig::printAfterIfEnabled( 211 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 212 // By default, never print. 213 } 214 215 //===----------------------------------------------------------------------===// 216 // PassManager 217 //===----------------------------------------------------------------------===// 218 219 namespace { 220 /// Simple wrapper config that allows for the simpler interface defined above. 221 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { 222 BasicIRPrinterConfig( 223 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 224 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 225 bool printModuleScope, bool printAfterOnlyOnChange, 226 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, 227 raw_ostream &out) 228 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 229 printAfterOnlyOnFailure, opPrintingFlags), 230 shouldPrintBeforePass(std::move(shouldPrintBeforePass)), 231 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { 232 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && 233 "expected at least one valid filter function"); 234 } 235 236 void printBeforeIfEnabled(Pass *pass, Operation *operation, 237 PrintCallbackFn printCallback) final { 238 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) 239 printCallback(out); 240 } 241 242 void printAfterIfEnabled(Pass *pass, Operation *operation, 243 PrintCallbackFn printCallback) final { 244 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) 245 printCallback(out); 246 } 247 248 /// Filter functions for before and after pass execution. 249 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 250 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 251 252 /// The stream to output to. 253 raw_ostream &out; 254 }; 255 } // namespace 256 257 /// Add an instrumentation to print the IR before and after pass execution, 258 /// using the provided configuration. 259 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { 260 if (config->shouldPrintAtModuleScope() && 261 getContext()->isMultithreadingEnabled()) 262 llvm::report_fatal_error("IR printing can't be setup on a pass-manager " 263 "without disabling multi-threading first."); 264 addInstrumentation( 265 std::make_unique<IRPrinterInstrumentation>(std::move(config))); 266 } 267 268 /// Add an instrumentation to print the IR before and after pass execution. 269 void PassManager::enableIRPrinting( 270 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 271 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 272 bool printModuleScope, bool printAfterOnlyOnChange, 273 bool printAfterOnlyOnFailure, raw_ostream &out, 274 OpPrintingFlags opPrintingFlags) { 275 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( 276 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 277 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, 278 opPrintingFlags, out)); 279 } 280