1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Rewrite/PatternApplicator.h" 15 #include "ByteCode.h" 16 #include "llvm/Support/Debug.h" 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 PatternApplicator::PatternApplicator( 22 const FrozenRewritePatternList &frozenPatternList) 23 : frozenPatternList(frozenPatternList) { 24 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 25 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); 26 bytecode->initializeMutableState(*mutableByteCodeState); 27 } 28 } 29 PatternApplicator::~PatternApplicator() {} 30 31 #define DEBUG_TYPE "pattern-match" 32 33 void PatternApplicator::applyCostModel(CostModel model) { 34 // Apply the cost model to the bytecode patterns first, and then the native 35 // patterns. 36 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 37 for (auto it : llvm::enumerate(bytecode->getPatterns())) 38 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); 39 } 40 41 // Separate patterns by root kind to simplify lookup later on. 42 patterns.clear(); 43 anyOpPatterns.clear(); 44 for (const auto &pat : frozenPatternList.getNativePatterns()) { 45 // If the pattern is always impossible to match, just ignore it. 46 if (pat.getBenefit().isImpossibleToMatch()) { 47 LLVM_DEBUG({ 48 llvm::dbgs() 49 << "Ignoring pattern '" << pat.getRootKind() 50 << "' because it is impossible to match (by pattern benefit)\n"; 51 }); 52 continue; 53 } 54 if (Optional<OperationName> opName = pat.getRootKind()) 55 patterns[*opName].push_back(&pat); 56 else 57 anyOpPatterns.push_back(&pat); 58 } 59 60 // Sort the patterns using the provided cost model. 61 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; 62 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { 63 return benefits[lhs] > benefits[rhs]; 64 }; 65 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { 66 // Special case for one pattern in the list, which is the most common case. 67 if (list.size() == 1) { 68 if (model(*list.front()).isImpossibleToMatch()) { 69 LLVM_DEBUG({ 70 llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind() 71 << "' because it is impossible to match or cannot lead " 72 "to legal IR (by cost model)\n"; 73 }); 74 list.clear(); 75 } 76 return; 77 } 78 79 // Collect the dynamic benefits for the current pattern list. 80 benefits.clear(); 81 for (const Pattern *pat : list) 82 benefits.try_emplace(pat, model(*pat)); 83 84 // Sort patterns with highest benefit first, and remove those that are 85 // impossible to match. 86 std::stable_sort(list.begin(), list.end(), cmp); 87 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { 88 LLVM_DEBUG({ 89 llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind() 90 << "' because it is impossible to match or cannot lead to " 91 "legal IR (by cost model)\n"; 92 }); 93 list.pop_back(); 94 } 95 }; 96 for (auto &it : patterns) 97 processPatternList(it.second); 98 processPatternList(anyOpPatterns); 99 } 100 101 void PatternApplicator::walkAllPatterns( 102 function_ref<void(const Pattern &)> walk) { 103 for (const Pattern &it : frozenPatternList.getNativePatterns()) 104 walk(it); 105 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 106 for (const Pattern &it : bytecode->getPatterns()) 107 walk(it); 108 } 109 } 110 111 LogicalResult PatternApplicator::matchAndRewrite( 112 Operation *op, PatternRewriter &rewriter, 113 function_ref<bool(const Pattern &)> canApply, 114 function_ref<void(const Pattern &)> onFailure, 115 function_ref<LogicalResult(const Pattern &)> onSuccess) { 116 // Before checking native patterns, first match against the bytecode. This 117 // won't automatically perform any rewrites so there is no need to worry about 118 // conflicts. 119 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; 120 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); 121 if (bytecode) 122 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); 123 124 // Check to see if there are patterns matching this specific operation type. 125 MutableArrayRef<const RewritePattern *> opPatterns; 126 auto patternIt = patterns.find(op->getName()); 127 if (patternIt != patterns.end()) 128 opPatterns = patternIt->second; 129 130 // Process the patterns for that match the specific operation type, and any 131 // operation type in an interleaved fashion. 132 unsigned opIt = 0, opE = opPatterns.size(); 133 unsigned anyIt = 0, anyE = anyOpPatterns.size(); 134 unsigned pdlIt = 0, pdlE = pdlMatches.size(); 135 LogicalResult result = failure(); 136 do { 137 // Find the next pattern with the highest benefit. 138 const Pattern *bestPattern = nullptr; 139 unsigned *bestPatternIt = &opIt; 140 const PDLByteCode::MatchResult *pdlMatch = nullptr; 141 142 /// Operation specific patterns. 143 if (opIt < opE) 144 bestPattern = opPatterns[opIt]; 145 /// Operation agnostic patterns. 146 if (anyIt < anyE && 147 (!bestPattern || 148 bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { 149 bestPatternIt = &anyIt; 150 bestPattern = anyOpPatterns[anyIt]; 151 } 152 /// PDL patterns. 153 if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < 154 pdlMatches[pdlIt].benefit)) { 155 bestPatternIt = &pdlIt; 156 pdlMatch = &pdlMatches[pdlIt]; 157 bestPattern = pdlMatch->pattern; 158 } 159 if (!bestPattern) 160 break; 161 162 // Update the pattern iterator on failure so that this pattern isn't 163 // attempted again. 164 ++(*bestPatternIt); 165 166 // Check that the pattern can be applied. 167 if (canApply && !canApply(*bestPattern)) 168 continue; 169 170 // Try to match and rewrite this pattern. The patterns are sorted by 171 // benefit, so if we match we can immediately rewrite. For PDL patterns, the 172 // match has already been performed, we just need to rewrite. 173 rewriter.setInsertionPoint(op); 174 if (pdlMatch) { 175 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); 176 result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); 177 178 } else { 179 const auto *pattern = static_cast<const RewritePattern *>(bestPattern); 180 result = pattern->matchAndRewrite(op, rewriter); 181 if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) 182 result = failure(); 183 } 184 if (succeeded(result)) 185 break; 186 187 // Perform any necessary cleanups. 188 if (onFailure) 189 onFailure(*bestPattern); 190 } while (true); 191 192 if (mutableByteCodeState) 193 mutableByteCodeState->cleanupAfterMatchAndRewrite(); 194 return result; 195 } 196