1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Rewrite/PatternApplicator.h" 15 #include "ByteCode.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "pattern-application" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 PatternApplicator::PatternApplicator( 24 const FrozenRewritePatternSet &frozenPatternList) 25 : frozenPatternList(frozenPatternList) { 26 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 27 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); 28 bytecode->initializeMutableState(*mutableByteCodeState); 29 } 30 } 31 PatternApplicator::~PatternApplicator() = default; 32 33 #ifndef NDEBUG 34 /// Log a message for a pattern that is impossible to match. 35 static void logImpossibleToMatch(const Pattern &pattern) { 36 llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() 37 << "' because it is impossible to match or cannot lead " 38 "to legal IR (by cost model)\n"; 39 } 40 41 /// Log IR after pattern application. 42 static Operation *getDumpRootOp(Operation *op) { 43 return op->getParentWithTrait<mlir::OpTrait::IsIsolatedFromAbove>(); 44 } 45 static void logSucessfulPatternApplication(Operation *op) { 46 llvm::dbgs() << "// *** IR Dump After Pattern Application ***\n"; 47 op->dump(); 48 llvm::dbgs() << "\n\n"; 49 } 50 #endif 51 52 void PatternApplicator::applyCostModel(CostModel model) { 53 // Apply the cost model to the bytecode patterns first, and then the native 54 // patterns. 55 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 56 for (auto it : llvm::enumerate(bytecode->getPatterns())) 57 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); 58 } 59 60 // Copy over the patterns so that we can sort by benefit based on the cost 61 // model. Patterns that are already impossible to match are ignored. 62 patterns.clear(); 63 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) { 64 for (const RewritePattern *pattern : it.second) { 65 if (pattern->getBenefit().isImpossibleToMatch()) 66 LLVM_DEBUG(logImpossibleToMatch(*pattern)); 67 else 68 patterns[it.first].push_back(pattern); 69 } 70 } 71 anyOpPatterns.clear(); 72 for (const RewritePattern &pattern : 73 frozenPatternList.getMatchAnyOpNativePatterns()) { 74 if (pattern.getBenefit().isImpossibleToMatch()) 75 LLVM_DEBUG(logImpossibleToMatch(pattern)); 76 else 77 anyOpPatterns.push_back(&pattern); 78 } 79 80 // Sort the patterns using the provided cost model. 81 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; 82 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { 83 return benefits[lhs] > benefits[rhs]; 84 }; 85 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { 86 // Special case for one pattern in the list, which is the most common case. 87 if (list.size() == 1) { 88 if (model(*list.front()).isImpossibleToMatch()) { 89 LLVM_DEBUG(logImpossibleToMatch(*list.front())); 90 list.clear(); 91 } 92 return; 93 } 94 95 // Collect the dynamic benefits for the current pattern list. 96 benefits.clear(); 97 for (const Pattern *pat : list) 98 benefits.try_emplace(pat, model(*pat)); 99 100 // Sort patterns with highest benefit first, and remove those that are 101 // impossible to match. 102 std::stable_sort(list.begin(), list.end(), cmp); 103 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { 104 LLVM_DEBUG(logImpossibleToMatch(*list.back())); 105 list.pop_back(); 106 } 107 }; 108 for (auto &it : patterns) 109 processPatternList(it.second); 110 processPatternList(anyOpPatterns); 111 } 112 113 void PatternApplicator::walkAllPatterns( 114 function_ref<void(const Pattern &)> walk) { 115 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) 116 for (const auto &pattern : it.second) 117 walk(*pattern); 118 for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) 119 walk(it); 120 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 121 for (const Pattern &it : bytecode->getPatterns()) 122 walk(it); 123 } 124 } 125 126 LogicalResult PatternApplicator::matchAndRewrite( 127 Operation *op, PatternRewriter &rewriter, 128 function_ref<bool(const Pattern &)> canApply, 129 function_ref<void(const Pattern &)> onFailure, 130 function_ref<LogicalResult(const Pattern &)> onSuccess) { 131 // Before checking native patterns, first match against the bytecode. This 132 // won't automatically perform any rewrites so there is no need to worry about 133 // conflicts. 134 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; 135 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); 136 if (bytecode) 137 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); 138 139 // Check to see if there are patterns matching this specific operation type. 140 MutableArrayRef<const RewritePattern *> opPatterns; 141 auto patternIt = patterns.find(op->getName()); 142 if (patternIt != patterns.end()) 143 opPatterns = patternIt->second; 144 145 // Process the patterns for that match the specific operation type, and any 146 // operation type in an interleaved fashion. 147 unsigned opIt = 0, opE = opPatterns.size(); 148 unsigned anyIt = 0, anyE = anyOpPatterns.size(); 149 unsigned pdlIt = 0, pdlE = pdlMatches.size(); 150 LogicalResult result = failure(); 151 do { 152 // Find the next pattern with the highest benefit. 153 const Pattern *bestPattern = nullptr; 154 unsigned *bestPatternIt = &opIt; 155 const PDLByteCode::MatchResult *pdlMatch = nullptr; 156 157 /// Operation specific patterns. 158 if (opIt < opE) 159 bestPattern = opPatterns[opIt]; 160 /// Operation agnostic patterns. 161 if (anyIt < anyE && 162 (!bestPattern || 163 bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { 164 bestPatternIt = &anyIt; 165 bestPattern = anyOpPatterns[anyIt]; 166 } 167 /// PDL patterns. 168 if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < 169 pdlMatches[pdlIt].benefit)) { 170 bestPatternIt = &pdlIt; 171 pdlMatch = &pdlMatches[pdlIt]; 172 bestPattern = pdlMatch->pattern; 173 } 174 if (!bestPattern) 175 break; 176 177 // Update the pattern iterator on failure so that this pattern isn't 178 // attempted again. 179 ++(*bestPatternIt); 180 181 // Check that the pattern can be applied. 182 if (canApply && !canApply(*bestPattern)) 183 continue; 184 185 // Try to match and rewrite this pattern. The patterns are sorted by 186 // benefit, so if we match we can immediately rewrite. For PDL patterns, the 187 // match has already been performed, we just need to rewrite. 188 rewriter.setInsertionPoint(op); 189 #ifndef NDEBUG 190 // Operation `op` may be invalidated after applying the rewrite pattern. 191 Operation *dumpRootOp = getDumpRootOp(op); 192 #endif 193 if (pdlMatch) { 194 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); 195 result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); 196 } else { 197 const auto *pattern = static_cast<const RewritePattern *>(bestPattern); 198 199 LLVM_DEBUG(llvm::dbgs() 200 << "Trying to match \"" << pattern->getDebugName() << "\"\n"); 201 result = pattern->matchAndRewrite(op, rewriter); 202 LLVM_DEBUG(llvm::dbgs() << "\"" << pattern->getDebugName() << "\" result " 203 << succeeded(result) << "\n"); 204 205 if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) 206 result = failure(); 207 } 208 if (succeeded(result)) { 209 LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp)); 210 break; 211 } 212 213 // Perform any necessary cleanups. 214 if (onFailure) 215 onFailure(*bestPattern); 216 } while (true); 217 218 if (mutableByteCodeState) 219 mutableByteCodeState->cleanupAfterMatchAndRewrite(); 220 return result; 221 } 222