1 //===- IRPrinting.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Pass/PassManager.h" 11 #include "llvm/Support/Format.h" 12 #include "llvm/Support/FormatVariadic.h" 13 #include "llvm/Support/SHA1.h" 14 15 using namespace mlir; 16 using namespace mlir::detail; 17 18 namespace { 19 //===----------------------------------------------------------------------===// 20 // OperationFingerPrint 21 //===----------------------------------------------------------------------===// 22 23 /// A unique fingerprint for a specific operation, and all of it's internal 24 /// operations. 25 class OperationFingerPrint { 26 public: 27 OperationFingerPrint(Operation *topOp) { 28 llvm::SHA1 hasher; 29 30 // Hash each of the operations based upon their mutable bits: 31 topOp->walk([&](Operation *op) { 32 // - Operation pointer 33 addDataToHash(hasher, op); 34 // - Attributes 35 addDataToHash(hasher, op->getAttrDictionary()); 36 // - Blocks in Regions 37 for (Region ®ion : op->getRegions()) { 38 for (Block &block : region) { 39 addDataToHash(hasher, &block); 40 for (BlockArgument arg : block.getArguments()) 41 addDataToHash(hasher, arg); 42 } 43 } 44 // - Location 45 addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); 46 // - Operands 47 for (Value operand : op->getOperands()) 48 addDataToHash(hasher, operand); 49 // - Successors 50 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) 51 addDataToHash(hasher, op->getSuccessor(i)); 52 }); 53 hash = hasher.result(); 54 } 55 56 bool operator==(const OperationFingerPrint &other) const { 57 return hash == other.hash; 58 } 59 bool operator!=(const OperationFingerPrint &other) const { 60 return !(*this == other); 61 } 62 63 private: 64 template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) { 65 hasher.update( 66 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T))); 67 } 68 69 SmallString<20> hash; 70 }; 71 72 //===----------------------------------------------------------------------===// 73 // IRPrinter 74 //===----------------------------------------------------------------------===// 75 76 class IRPrinterInstrumentation : public PassInstrumentation { 77 public: 78 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) 79 : config(std::move(config)) {} 80 81 private: 82 /// Instrumentation hooks. 83 void runBeforePass(Pass *pass, Operation *op) override; 84 void runAfterPass(Pass *pass, Operation *op) override; 85 void runAfterPassFailed(Pass *pass, Operation *op) override; 86 87 /// Configuration to use. 88 std::unique_ptr<PassManager::IRPrinterConfig> config; 89 90 /// The following is a set of fingerprints for operations that are currently 91 /// being operated on in a pass. This field is only used when the 92 /// configuration asked for change detection. 93 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; 94 }; 95 } // namespace 96 97 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, 98 OpPrintingFlags flags) { 99 // Otherwise, check to see if we are not printing at module scope. 100 if (!printModuleScope) 101 return op->print(out << " //----- //\n", 102 op->getBlock() ? flags.useLocalScope() : flags); 103 104 // Otherwise, we are printing at module scope. 105 out << " ('" << op->getName() << "' operation"; 106 if (auto symbolName = 107 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) 108 out << ": @" << symbolName.getValue(); 109 out << ") //----- //\n"; 110 111 // Find the top-level operation. 112 auto *topLevelOp = op; 113 while (auto *parentOp = topLevelOp->getParentOp()) 114 topLevelOp = parentOp; 115 topLevelOp->print(out, flags); 116 } 117 118 /// Instrumentation hooks. 119 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { 120 if (isa<OpToOpPassAdaptor>(pass)) 121 return; 122 // If the config asked to detect changes, record the current fingerprint. 123 if (config->shouldPrintAfterOnlyOnChange()) 124 beforePassFingerPrints.try_emplace(pass, op); 125 126 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { 127 out << "// -----// IR Dump Before " << pass->getName(); 128 printIR(op, config->shouldPrintAtModuleScope(), out, 129 config->getOpPrintingFlags()); 130 out << "\n\n"; 131 }); 132 } 133 134 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { 135 if (isa<OpToOpPassAdaptor>(pass)) 136 return; 137 138 // Check to see if we are only printing on failure. 139 if (config->shouldPrintAfterOnlyOnFailure()) 140 return; 141 142 // If the config asked to detect changes, compare the current fingerprint with 143 // the previous. 144 if (config->shouldPrintAfterOnlyOnChange()) { 145 auto fingerPrintIt = beforePassFingerPrints.find(pass); 146 assert(fingerPrintIt != beforePassFingerPrints.end() && 147 "expected valid fingerprint"); 148 // If the fingerprints are the same, we don't print the IR. 149 if (fingerPrintIt->second == OperationFingerPrint(op)) { 150 beforePassFingerPrints.erase(fingerPrintIt); 151 return; 152 } 153 beforePassFingerPrints.erase(fingerPrintIt); 154 } 155 156 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 157 out << "// -----// IR Dump After " << pass->getName(); 158 printIR(op, config->shouldPrintAtModuleScope(), out, 159 config->getOpPrintingFlags()); 160 out << "\n\n"; 161 }); 162 } 163 164 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { 165 if (isa<OpToOpPassAdaptor>(pass)) 166 return; 167 if (config->shouldPrintAfterOnlyOnChange()) 168 beforePassFingerPrints.erase(pass); 169 170 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 171 out << formatv("// -----// IR Dump After {0} Failed", pass->getName()); 172 printIR(op, config->shouldPrintAtModuleScope(), out, 173 OpPrintingFlags().printGenericOpForm()); 174 out << "\n\n"; 175 }); 176 } 177 178 //===----------------------------------------------------------------------===// 179 // IRPrinterConfig 180 //===----------------------------------------------------------------------===// 181 182 /// Initialize the configuration. 183 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, 184 bool printAfterOnlyOnChange, 185 bool printAfterOnlyOnFailure, 186 OpPrintingFlags opPrintingFlags) 187 : printModuleScope(printModuleScope), 188 printAfterOnlyOnChange(printAfterOnlyOnChange), 189 printAfterOnlyOnFailure(printAfterOnlyOnFailure), 190 opPrintingFlags(opPrintingFlags) {} 191 PassManager::IRPrinterConfig::~IRPrinterConfig() = default; 192 193 /// A hook that may be overridden by a derived config that checks if the IR 194 /// of 'operation' should be dumped *before* the pass 'pass' has been 195 /// executed. If the IR should be dumped, 'printCallback' should be invoked 196 /// with the stream to dump into. 197 void PassManager::IRPrinterConfig::printBeforeIfEnabled( 198 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 199 // By default, never print. 200 } 201 202 /// A hook that may be overridden by a derived config that checks if the IR 203 /// of 'operation' should be dumped *after* the pass 'pass' has been 204 /// executed. If the IR should be dumped, 'printCallback' should be invoked 205 /// with the stream to dump into. 206 void PassManager::IRPrinterConfig::printAfterIfEnabled( 207 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 208 // By default, never print. 209 } 210 211 //===----------------------------------------------------------------------===// 212 // PassManager 213 //===----------------------------------------------------------------------===// 214 215 namespace { 216 /// Simple wrapper config that allows for the simpler interface defined above. 217 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { 218 BasicIRPrinterConfig( 219 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 220 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 221 bool printModuleScope, bool printAfterOnlyOnChange, 222 bool printAfterOnlyOnFailure, OpPrintingFlags opPrintingFlags, 223 raw_ostream &out) 224 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 225 printAfterOnlyOnFailure, opPrintingFlags), 226 shouldPrintBeforePass(std::move(shouldPrintBeforePass)), 227 shouldPrintAfterPass(std::move(shouldPrintAfterPass)), out(out) { 228 assert((this->shouldPrintBeforePass || this->shouldPrintAfterPass) && 229 "expected at least one valid filter function"); 230 } 231 232 void printBeforeIfEnabled(Pass *pass, Operation *operation, 233 PrintCallbackFn printCallback) final { 234 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) 235 printCallback(out); 236 } 237 238 void printAfterIfEnabled(Pass *pass, Operation *operation, 239 PrintCallbackFn printCallback) final { 240 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) 241 printCallback(out); 242 } 243 244 /// Filter functions for before and after pass execution. 245 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 246 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 247 248 /// The stream to output to. 249 raw_ostream &out; 250 }; 251 } // namespace 252 253 /// Add an instrumentation to print the IR before and after pass execution, 254 /// using the provided configuration. 255 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { 256 if (config->shouldPrintAtModuleScope() && 257 getContext()->isMultithreadingEnabled()) 258 llvm::report_fatal_error("IR printing can't be setup on a pass-manager " 259 "without disabling multi-threading first."); 260 addInstrumentation( 261 std::make_unique<IRPrinterInstrumentation>(std::move(config))); 262 } 263 264 /// Add an instrumentation to print the IR before and after pass execution. 265 void PassManager::enableIRPrinting( 266 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 267 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 268 bool printModuleScope, bool printAfterOnlyOnChange, 269 bool printAfterOnlyOnFailure, raw_ostream &out, 270 OpPrintingFlags opPrintingFlags) { 271 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( 272 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 273 printModuleScope, printAfterOnlyOnChange, printAfterOnlyOnFailure, 274 opPrintingFlags, out)); 275 } 276