1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Rewrite/PatternApplicator.h" 15 #include "ByteCode.h" 16 #include "llvm/Support/Debug.h" 17 18 #define DEBUG_TYPE "pattern-match" 19 20 using namespace mlir; 21 using namespace mlir::detail; 22 23 PatternApplicator::PatternApplicator( 24 const FrozenRewritePatternSet &frozenPatternList) 25 : frozenPatternList(frozenPatternList) { 26 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 27 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); 28 bytecode->initializeMutableState(*mutableByteCodeState); 29 } 30 } 31 PatternApplicator::~PatternApplicator() {} 32 33 /// Log a message for a pattern that is impossible to match. 34 static void logImpossibleToMatch(const Pattern &pattern) { 35 LLVM_DEBUG({ 36 llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind() 37 << "' because it is impossible to match or cannot lead " 38 "to legal IR (by cost model)\n"; 39 }); 40 } 41 42 void PatternApplicator::applyCostModel(CostModel model) { 43 // Apply the cost model to the bytecode patterns first, and then the native 44 // patterns. 45 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 46 for (auto it : llvm::enumerate(bytecode->getPatterns())) 47 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); 48 } 49 50 // Copy over the patterns so that we can sort by benefit based on the cost 51 // model. Patterns that are already impossible to match are ignored. 52 patterns.clear(); 53 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) { 54 for (const RewritePattern *pattern : it.second) { 55 if (pattern->getBenefit().isImpossibleToMatch()) 56 logImpossibleToMatch(*pattern); 57 else 58 patterns[it.first].push_back(pattern); 59 } 60 } 61 anyOpPatterns.clear(); 62 for (const RewritePattern &pattern : 63 frozenPatternList.getMatchAnyOpNativePatterns()) { 64 if (pattern.getBenefit().isImpossibleToMatch()) 65 logImpossibleToMatch(pattern); 66 else 67 anyOpPatterns.push_back(&pattern); 68 } 69 70 // Sort the patterns using the provided cost model. 71 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; 72 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { 73 return benefits[lhs] > benefits[rhs]; 74 }; 75 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { 76 // Special case for one pattern in the list, which is the most common case. 77 if (list.size() == 1) { 78 if (model(*list.front()).isImpossibleToMatch()) { 79 logImpossibleToMatch(*list.front()); 80 list.clear(); 81 } 82 return; 83 } 84 85 // Collect the dynamic benefits for the current pattern list. 86 benefits.clear(); 87 for (const Pattern *pat : list) 88 benefits.try_emplace(pat, model(*pat)); 89 90 // Sort patterns with highest benefit first, and remove those that are 91 // impossible to match. 92 std::stable_sort(list.begin(), list.end(), cmp); 93 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) 94 logImpossibleToMatch(*list.pop_back_val()); 95 }; 96 for (auto &it : patterns) 97 processPatternList(it.second); 98 processPatternList(anyOpPatterns); 99 } 100 101 void PatternApplicator::walkAllPatterns( 102 function_ref<void(const Pattern &)> walk) { 103 for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) 104 for (const auto &pattern : it.second) 105 walk(*pattern); 106 for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns()) 107 walk(it); 108 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 109 for (const Pattern &it : bytecode->getPatterns()) 110 walk(it); 111 } 112 } 113 114 LogicalResult PatternApplicator::matchAndRewrite( 115 Operation *op, PatternRewriter &rewriter, 116 function_ref<bool(const Pattern &)> canApply, 117 function_ref<void(const Pattern &)> onFailure, 118 function_ref<LogicalResult(const Pattern &)> onSuccess) { 119 // Before checking native patterns, first match against the bytecode. This 120 // won't automatically perform any rewrites so there is no need to worry about 121 // conflicts. 122 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; 123 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); 124 if (bytecode) 125 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); 126 127 // Check to see if there are patterns matching this specific operation type. 128 MutableArrayRef<const RewritePattern *> opPatterns; 129 auto patternIt = patterns.find(op->getName()); 130 if (patternIt != patterns.end()) 131 opPatterns = patternIt->second; 132 133 // Process the patterns for that match the specific operation type, and any 134 // operation type in an interleaved fashion. 135 unsigned opIt = 0, opE = opPatterns.size(); 136 unsigned anyIt = 0, anyE = anyOpPatterns.size(); 137 unsigned pdlIt = 0, pdlE = pdlMatches.size(); 138 LogicalResult result = failure(); 139 do { 140 // Find the next pattern with the highest benefit. 141 const Pattern *bestPattern = nullptr; 142 unsigned *bestPatternIt = &opIt; 143 const PDLByteCode::MatchResult *pdlMatch = nullptr; 144 145 /// Operation specific patterns. 146 if (opIt < opE) 147 bestPattern = opPatterns[opIt]; 148 /// Operation agnostic patterns. 149 if (anyIt < anyE && 150 (!bestPattern || 151 bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) { 152 bestPatternIt = &anyIt; 153 bestPattern = anyOpPatterns[anyIt]; 154 } 155 /// PDL patterns. 156 if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() < 157 pdlMatches[pdlIt].benefit)) { 158 bestPatternIt = &pdlIt; 159 pdlMatch = &pdlMatches[pdlIt]; 160 bestPattern = pdlMatch->pattern; 161 } 162 if (!bestPattern) 163 break; 164 165 // Update the pattern iterator on failure so that this pattern isn't 166 // attempted again. 167 ++(*bestPatternIt); 168 169 // Check that the pattern can be applied. 170 if (canApply && !canApply(*bestPattern)) 171 continue; 172 173 // Try to match and rewrite this pattern. The patterns are sorted by 174 // benefit, so if we match we can immediately rewrite. For PDL patterns, the 175 // match has already been performed, we just need to rewrite. 176 rewriter.setInsertionPoint(op); 177 if (pdlMatch) { 178 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); 179 result = success(!onSuccess || succeeded(onSuccess(*bestPattern))); 180 181 } else { 182 const auto *pattern = static_cast<const RewritePattern *>(bestPattern); 183 result = pattern->matchAndRewrite(op, rewriter); 184 if (succeeded(result) && onSuccess && failed(onSuccess(*pattern))) 185 result = failure(); 186 } 187 if (succeeded(result)) 188 break; 189 190 // Perform any necessary cleanups. 191 if (onFailure) 192 onFailure(*bestPattern); 193 } while (true); 194 195 if (mutableByteCodeState) 196 mutableByteCodeState->cleanupAfterMatchAndRewrite(); 197 return result; 198 } 199