1 //===- IRPrinting.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/IR/SymbolTable.h" 11 #include "mlir/Pass/PassManager.h" 12 #include "llvm/Support/Format.h" 13 #include "llvm/Support/FormatVariadic.h" 14 #include "llvm/Support/SHA1.h" 15 16 using namespace mlir; 17 using namespace mlir::detail; 18 19 namespace { 20 //===----------------------------------------------------------------------===// 21 // OperationFingerPrint 22 //===----------------------------------------------------------------------===// 23 24 /// A unique fingerprint for a specific operation, and all of it's internal 25 /// operations. 26 class OperationFingerPrint { 27 public: 28 OperationFingerPrint(Operation *topOp) { 29 llvm::SHA1 hasher; 30 31 // Hash each of the operations based upon their mutable bits: 32 topOp->walk([&](Operation *op) { 33 // - Operation pointer 34 addDataToHash(hasher, op); 35 // - Attributes 36 addDataToHash(hasher, op->getAttrDictionary()); 37 // - Blocks in Regions 38 for (Region ®ion : op->getRegions()) { 39 for (Block &block : region) { 40 addDataToHash(hasher, &block); 41 for (BlockArgument arg : block.getArguments()) 42 addDataToHash(hasher, arg); 43 } 44 } 45 // - Location 46 addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); 47 // - Operands 48 for (Value operand : op->getOperands()) 49 addDataToHash(hasher, operand); 50 // - Successors 51 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) 52 addDataToHash(hasher, op->getSuccessor(i)); 53 }); 54 hash = hasher.result(); 55 } 56 57 bool operator==(const OperationFingerPrint &other) const { 58 return hash == other.hash; 59 } 60 bool operator!=(const OperationFingerPrint &other) const { 61 return !(*this == other); 62 } 63 64 private: 65 template <typename T> 66 void addDataToHash(llvm::SHA1 &hasher, const T &data) { 67 hasher.update( 68 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T))); 69 } 70 71 std::array<uint8_t, 20> hash; 72 }; 73 74 //===----------------------------------------------------------------------===// 75 // IRPrinter 76 //===----------------------------------------------------------------------===// 77 78 class IRPrinterInstrumentation : public PassInstrumentation { 79 public: 80 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) 81 : config(std::move(config)) {} 82 83 private: 84 /// Instrumentation hooks. 85 void runBeforePass(Pass *pass, Operation *op) override; 86 void runAfterPass(Pass *pass, Operation *op) override; 87 void runAfterPassFailed(Pass *pass, Operation *op) override; 88 89 /// Configuration to use. 90 std::unique_ptr<PassManager::IRPrinterConfig> config; 91 92 /// The following is a set of fingerprints for operations that are currently 93 /// being operated on in a pass. This field is only used when the 94 /// configuration asked for change detection. 95 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; 96 }; 97 } // namespace 98 99 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, 100 OpPrintingFlags flags) { 101 // Otherwise, check to see if we are not printing at module scope. 102 if (!printModuleScope) 103 return op->print(out << " //----- //\n", 104 op->getBlock() ? flags.useLocalScope() : flags); 105 106 // Otherwise, we are printing at module scope. 107 out << " ('" << op->getName() << "' operation"; 108 if (auto symbolName = 109 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) 110 out << ": @" << symbolName.getValue(); 111 out << ") //----- //\n"; 112 113 // Find the top-level operation. 114 auto *topLevelOp = op; 115 while (auto *parentOp = topLevelOp->getParentOp()) 116 topLevelOp = parentOp; 117 topLevelOp->print(out, flags); 118 } 119 120 /// Instrumentation hooks. 121 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { 122 if (isa<OpToOpPassAdaptor>(pass)) 123 return; 124 // If the config asked to detect changes, record the current fingerprint. 125 if (config->shouldPrintAfterOnlyOnChange()) 126 beforePassFingerPrints.try_emplace(pass, op); 127 128 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { 129 out << "// -----// IR Dump Before " << pass->getName(); 130 printIR(op, config->shouldPrintAtModuleScope(), out, 131 config->getOpPrintingFlags()); 132 out << "\n\n"; 133 }); 134 } 135 136 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { 137 if (isa<OpToOpPassAdaptor>(pass)) 138 return; 139 140 // Check to see if we are only printing on failure. 141 if (config->shouldPrintAfterOnlyOnFailure()) 142 return; 143 144 // If the config asked to detect changes, compare the current fingerprint with 145 // the previous. 146 if (config->shouldPrintAfterOnlyOnChange()) { 147 auto fingerPrintIt = beforePassFingerPrints.find(pass); 148 assert(fingerPrintIt != beforePassFingerPrints.end() && 149 "expected valid fingerprint"); 150 // If the fingerprints are the same, we don't print the IR. 151 if (fingerPrintIt->second == OperationFingerPrint(op)) { 152 beforePassFingerPrints.erase(fingerPrintIt); 153 return; 154 } 155 beforePassFingerPrints.erase(fingerPrintIt); 156 } 157 158 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 159 out << "// -----// IR Dump After " << pass->getName(); 160 printIR(op, config->shouldPrintAtModuleScope(), out, 161 config->getOpPrintingFlags()); 162 out << "\n\n"; 163 }); 164 } 165 166 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { 167 if (isa<OpToOpPassAdaptor>(pass)) 168 return; 169 if (config->shouldPrintAfterOnlyOnChange()) 170 beforePassFingerPrints.erase(pass); 171 172 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 173 out << formatv("// -----// IR Dump After {0} Failed", pass->getName()); 174 printIR(op, config->shouldPrintAtModuleScope(), out, OpPrintingFlags()); 175 out << "\n\n"; 176 }); 177 } 178 179 //===----------------------------------------------------------------------===// 180 // IRPrinterConfig 181 //===----------------------------------------------------------------------===// 182 183 /// Initialize the configuration. 184 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, 185 bool printAfterOnlyOnChange, 186 bool printAfterOnlyOnFailure, 187 OpPrintingFlags opPrintingFlags) 188 : printModuleScope(printModuleScope), 189 printAfterOnlyOnChange(printAfterOnlyOnChange), 190 printAfterOnlyOnFailure(printAfterOnlyOnFailure), 191 opPrintingFlags(opPrintingFlags) {} 192 PassManager::IRPrinterConfig::~IRPrinterConfig() = default; 193 194 /// A hook that may be overridden by a derived config that checks if the IR 195 /// of 'operation' should be dumped *before* the pass 'pass' has been 196 /// executed. If the IR should be dumped, 'printCallback' should be invoked 197 /// with the stream to dump into. 198 void PassManager::IRPrinterConfig::printBeforeIfEnabled( 199 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 200 // By default, never print. 201 } 202 203 /// A hook that may be overridden by a derived config that checks if the IR 204 /// of 'operation' should be dumped *after* the pass 'pass' has been 205 /// executed. If the IR should be dumped, 'printCallback' should be invoked 206 /// with the stream to dump into. 207 void PassManager::IRPrinterConfig::printAfterIfEnabled( 208 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 209 // By default, never print. 210 } 211 212 //===----------------------------------------------------------------------===// 213 // PassManager 214 //===----------------------------------------------------------------------===// 215 216 namespace { 217 /// Simple wrapper config that allows for the simpler interface defined above. 218 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { 219 BasicIRPrinterConfig( 220 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 221 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 222 bool printModuleScope, bool printAfterOnlyOnChange, 223 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, 224 raw_ostream &out) 225 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 226 printAfterOnlyOnFailure, opPrintingFlags), 227 shouldPrintBeforePass(std::move(shouldPrintBeforePass)), 228 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { 229 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && 230 "expected at least one valid filter function"); 231 } 232 233 void printBeforeIfEnabled(Pass *pass, Operation *operation, 234 PrintCallbackFn printCallback) final { 235 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) 236 printCallback(out); 237 } 238 239 void printAfterIfEnabled(Pass *pass, Operation *operation, 240 PrintCallbackFn printCallback) final { 241 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) 242 printCallback(out); 243 } 244 245 /// Filter functions for before and after pass execution. 246 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 247 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 248 249 /// The stream to output to. 250 raw_ostream &out; 251 }; 252 } // namespace 253 254 /// Add an instrumentation to print the IR before and after pass execution, 255 /// using the provided configuration. 256 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { 257 if (config->shouldPrintAtModuleScope() && 258 getContext()->isMultithreadingEnabled()) 259 llvm::report_fatal_error("IR printing can't be setup on a pass-manager " 260 "without disabling multi-threading first."); 261 addInstrumentation( 262 std::make_unique<IRPrinterInstrumentation>(std::move(config))); 263 } 264 265 /// Add an instrumentation to print the IR before and after pass execution. 266 void PassManager::enableIRPrinting( 267 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 268 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 269 bool printModuleScope, bool printAfterOnlyOnChange, 270 bool printAfterOnlyOnFailure, raw_ostream &out, 271 OpPrintingFlags opPrintingFlags) { 272 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( 273 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 274 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, 275 opPrintingFlags, out)); 276 } 277