1 //===- IRPrinting.cpp -----------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "PassDetail.h" 10 #include "mlir/Pass/PassManager.h" 11 #include "llvm/Support/Format.h" 12 #include "llvm/Support/FormatVariadic.h" 13 #include "llvm/Support/SHA1.h" 14 15 using namespace mlir; 16 using namespace mlir::detail; 17 18 namespace { 19 //===----------------------------------------------------------------------===// 20 // OperationFingerPrint 21 //===----------------------------------------------------------------------===// 22 23 /// A unique fingerprint for a specific operation, and all of it's internal 24 /// operations. 25 class OperationFingerPrint { 26 public: 27 OperationFingerPrint(Operation *topOp) { 28 llvm::SHA1 hasher; 29 30 // Hash each of the operations based upon their mutable bits: 31 topOp->walk([&](Operation *op) { 32 // - Operation pointer 33 addDataToHash(hasher, op); 34 // - Attributes 35 addDataToHash(hasher, op->getAttrDictionary()); 36 // - Blocks in Regions 37 for (Region ®ion : op->getRegions()) { 38 for (Block &block : region) { 39 addDataToHash(hasher, &block); 40 for (BlockArgument arg : block.getArguments()) 41 addDataToHash(hasher, arg); 42 } 43 } 44 // - Location 45 addDataToHash(hasher, op->getLoc().getAsOpaquePointer()); 46 // - Operands 47 for (Value operand : op->getOperands()) 48 addDataToHash(hasher, operand); 49 // - Successors 50 for (unsigned i = 0, e = op->getNumSuccessors(); i != e; ++i) 51 addDataToHash(hasher, op->getSuccessor(i)); 52 }); 53 hash = hasher.result(); 54 } 55 56 bool operator==(const OperationFingerPrint &other) const { 57 return hash == other.hash; 58 } 59 bool operator!=(const OperationFingerPrint &other) const { 60 return !(*this == other); 61 } 62 63 private: 64 template <typename T> void addDataToHash(llvm::SHA1 &hasher, const T &data) { 65 hasher.update( 66 ArrayRef<uint8_t>(reinterpret_cast<const uint8_t *>(&data), sizeof(T))); 67 } 68 69 SmallString<20> hash; 70 }; 71 72 //===----------------------------------------------------------------------===// 73 // IRPrinter 74 //===----------------------------------------------------------------------===// 75 76 class IRPrinterInstrumentation : public PassInstrumentation { 77 public: 78 IRPrinterInstrumentation(std::unique_ptr<PassManager::IRPrinterConfig> config) 79 : config(std::move(config)) {} 80 81 private: 82 /// Instrumentation hooks. 83 void runBeforePass(Pass *pass, Operation *op) override; 84 void runAfterPass(Pass *pass, Operation *op) override; 85 void runAfterPassFailed(Pass *pass, Operation *op) override; 86 87 /// Configuration to use. 88 std::unique_ptr<PassManager::IRPrinterConfig> config; 89 90 /// The following is a set of fingerprints for operations that are currently 91 /// being operated on in a pass. This field is only used when the 92 /// configuration asked for change detection. 93 DenseMap<Pass *, OperationFingerPrint> beforePassFingerPrints; 94 }; 95 } // end anonymous namespace 96 97 static void printIR(Operation *op, bool printModuleScope, raw_ostream &out, 98 OpPrintingFlags flags) { 99 // Otherwise, check to see if we are not printing at module scope. 100 if (!printModuleScope) 101 return op->print(out << "\n", 102 op->getBlock() ? flags.useLocalScope() : flags); 103 104 // Otherwise, we are printing at module scope. 105 out << " ('" << op->getName() << "' operation"; 106 if (auto symbolName = 107 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())) 108 out << ": @" << symbolName.getValue(); 109 out << ")\n"; 110 111 // Find the top-level operation. 112 auto *topLevelOp = op; 113 while (auto *parentOp = topLevelOp->getParentOp()) 114 topLevelOp = parentOp; 115 topLevelOp->print(out, flags); 116 } 117 118 /// Instrumentation hooks. 119 void IRPrinterInstrumentation::runBeforePass(Pass *pass, Operation *op) { 120 if (isa<OpToOpPassAdaptor>(pass)) 121 return; 122 // If the config asked to detect changes, record the current fingerprint. 123 if (config->shouldPrintAfterOnlyOnChange()) 124 beforePassFingerPrints.try_emplace(pass, op); 125 126 config->printBeforeIfEnabled(pass, op, [&](raw_ostream &out) { 127 out << formatv("// *** IR Dump Before {0} ***", pass->getName()); 128 printIR(op, config->shouldPrintAtModuleScope(), out, 129 config->getOpPrintingFlags()); 130 out << "\n\n"; 131 }); 132 } 133 134 void IRPrinterInstrumentation::runAfterPass(Pass *pass, Operation *op) { 135 if (isa<OpToOpPassAdaptor>(pass)) 136 return; 137 // If the config asked to detect changes, compare the current fingerprint with 138 // the previous. 139 if (config->shouldPrintAfterOnlyOnChange()) { 140 auto fingerPrintIt = beforePassFingerPrints.find(pass); 141 assert(fingerPrintIt != beforePassFingerPrints.end() && 142 "expected valid fingerprint"); 143 // If the fingerprints are the same, we don't print the IR. 144 if (fingerPrintIt->second == OperationFingerPrint(op)) { 145 beforePassFingerPrints.erase(fingerPrintIt); 146 return; 147 } 148 beforePassFingerPrints.erase(fingerPrintIt); 149 } 150 151 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 152 out << formatv("// *** IR Dump After {0} ***", pass->getName()); 153 printIR(op, config->shouldPrintAtModuleScope(), out, 154 config->getOpPrintingFlags()); 155 out << "\n\n"; 156 }); 157 } 158 159 void IRPrinterInstrumentation::runAfterPassFailed(Pass *pass, Operation *op) { 160 if (isa<OpToOpPassAdaptor>(pass)) 161 return; 162 if (config->shouldPrintAfterOnlyOnChange()) 163 beforePassFingerPrints.erase(pass); 164 165 config->printAfterIfEnabled(pass, op, [&](raw_ostream &out) { 166 out << formatv("// *** IR Dump After {0} Failed ***", pass->getName()); 167 printIR(op, config->shouldPrintAtModuleScope(), out, 168 OpPrintingFlags().printGenericOpForm()); 169 out << "\n\n"; 170 }); 171 } 172 173 //===----------------------------------------------------------------------===// 174 // IRPrinterConfig 175 //===----------------------------------------------------------------------===// 176 177 /// Initialize the configuration. 178 PassManager::IRPrinterConfig::IRPrinterConfig(bool printModuleScope, 179 bool printAfterOnlyOnChange, 180 OpPrintingFlags opPrintingFlags) 181 : printModuleScope(printModuleScope), 182 printAfterOnlyOnChange(printAfterOnlyOnChange), 183 opPrintingFlags(opPrintingFlags) {} 184 PassManager::IRPrinterConfig::~IRPrinterConfig() {} 185 186 /// A hook that may be overridden by a derived config that checks if the IR 187 /// of 'operation' should be dumped *before* the pass 'pass' has been 188 /// executed. If the IR should be dumped, 'printCallback' should be invoked 189 /// with the stream to dump into. 190 void PassManager::IRPrinterConfig::printBeforeIfEnabled( 191 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 192 // By default, never print. 193 } 194 195 /// A hook that may be overridden by a derived config that checks if the IR 196 /// of 'operation' should be dumped *after* the pass 'pass' has been 197 /// executed. If the IR should be dumped, 'printCallback' should be invoked 198 /// with the stream to dump into. 199 void PassManager::IRPrinterConfig::printAfterIfEnabled( 200 Pass *pass, Operation *operation, PrintCallbackFn printCallback) { 201 // By default, never print. 202 } 203 204 //===----------------------------------------------------------------------===// 205 // PassManager 206 //===----------------------------------------------------------------------===// 207 208 namespace { 209 /// Simple wrapper config that allows for the simpler interface defined above. 210 struct BasicIRPrinterConfig : public PassManager::IRPrinterConfig { 211 BasicIRPrinterConfig( 212 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 213 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 214 bool printModuleScope, bool printAfterOnlyOnChange, 215 OpPrintingFlags opPrintingFlags, raw_ostream &out) 216 : IRPrinterConfig(printModuleScope, printAfterOnlyOnChange, 217 opPrintingFlags), 218 shouldPrintBeforePass(shouldPrintBeforePass), 219 shouldPrintAfterPass(shouldPrintAfterPass), out(out) { 220 assert((shouldPrintBeforePass || shouldPrintAfterPass) && 221 "expected at least one valid filter function"); 222 } 223 224 void printBeforeIfEnabled(Pass *pass, Operation *operation, 225 PrintCallbackFn printCallback) final { 226 if (shouldPrintBeforePass && shouldPrintBeforePass(pass, operation)) 227 printCallback(out); 228 } 229 230 void printAfterIfEnabled(Pass *pass, Operation *operation, 231 PrintCallbackFn printCallback) final { 232 if (shouldPrintAfterPass && shouldPrintAfterPass(pass, operation)) 233 printCallback(out); 234 } 235 236 /// Filter functions for before and after pass execution. 237 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass; 238 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass; 239 240 /// The stream to output to. 241 raw_ostream &out; 242 }; 243 } // end anonymous namespace 244 245 /// Add an instrumentation to print the IR before and after pass execution, 246 /// using the provided configuration. 247 void PassManager::enableIRPrinting(std::unique_ptr<IRPrinterConfig> config) { 248 if (config->shouldPrintAtModuleScope() && 249 getContext()->isMultithreadingEnabled()) 250 llvm::report_fatal_error("IR printing can't be setup on a pass-manager " 251 "without disabling multi-threading first."); 252 addInstrumentation( 253 std::make_unique<IRPrinterInstrumentation>(std::move(config))); 254 } 255 256 /// Add an instrumentation to print the IR before and after pass execution. 257 void PassManager::enableIRPrinting( 258 std::function<bool(Pass *, Operation *)> shouldPrintBeforePass, 259 std::function<bool(Pass *, Operation *)> shouldPrintAfterPass, 260 bool printModuleScope, bool printAfterOnlyOnChange, raw_ostream &out, 261 OpPrintingFlags opPrintingFlags) { 262 enableIRPrinting(std::make_unique<BasicIRPrinterConfig>( 263 std::move(shouldPrintBeforePass), std::move(shouldPrintAfterPass), 264 printModuleScope, printAfterOnlyOnChange, opPrintingFlags, out)); 265 } 266