1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Rewrite/PatternApplicator.h" 15 #include "ByteCode.h" 16 #include "llvm/Support/Debug.h" 17 18 using namespace mlir; 19 using namespace mlir::detail; 20 21 PatternApplicator::PatternApplicator( 22 const FrozenRewritePatternList &frozenPatternList) 23 : frozenPatternList(frozenPatternList) { 24 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 25 mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>(); 26 bytecode->initializeMutableState(*mutableByteCodeState); 27 } 28 } 29 PatternApplicator::~PatternApplicator() {} 30 31 #define DEBUG_TYPE "pattern-match" 32 33 void PatternApplicator::applyCostModel(CostModel model) { 34 // Apply the cost model to the bytecode patterns first, and then the native 35 // patterns. 36 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 37 for (auto it : llvm::enumerate(bytecode->getPatterns())) 38 mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value())); 39 } 40 41 // Separate patterns by root kind to simplify lookup later on. 42 patterns.clear(); 43 anyOpPatterns.clear(); 44 for (const auto &pat : frozenPatternList.getNativePatterns()) { 45 // If the pattern is always impossible to match, just ignore it. 46 if (pat.getBenefit().isImpossibleToMatch()) { 47 LLVM_DEBUG({ 48 llvm::dbgs() 49 << "Ignoring pattern '" << pat.getRootKind() 50 << "' because it is impossible to match (by pattern benefit)\n"; 51 }); 52 continue; 53 } 54 if (Optional<OperationName> opName = pat.getRootKind()) 55 patterns[*opName].push_back(&pat); 56 else 57 anyOpPatterns.push_back(&pat); 58 } 59 60 // Sort the patterns using the provided cost model. 61 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; 62 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { 63 return benefits[lhs] > benefits[rhs]; 64 }; 65 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { 66 // Special case for one pattern in the list, which is the most common case. 67 if (list.size() == 1) { 68 if (model(*list.front()).isImpossibleToMatch()) { 69 LLVM_DEBUG({ 70 llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind() 71 << "' because it is impossible to match or cannot lead " 72 "to legal IR (by cost model)\n"; 73 }); 74 list.clear(); 75 } 76 return; 77 } 78 79 // Collect the dynamic benefits for the current pattern list. 80 benefits.clear(); 81 for (const Pattern *pat : list) 82 benefits.try_emplace(pat, model(*pat)); 83 84 // Sort patterns with highest benefit first, and remove those that are 85 // impossible to match. 86 std::stable_sort(list.begin(), list.end(), cmp); 87 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { 88 LLVM_DEBUG({ 89 llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind() 90 << "' because it is impossible to match or cannot lead to " 91 "legal IR (by cost model)\n"; 92 }); 93 list.pop_back(); 94 } 95 }; 96 for (auto &it : patterns) 97 processPatternList(it.second); 98 processPatternList(anyOpPatterns); 99 } 100 101 void PatternApplicator::walkAllPatterns( 102 function_ref<void(const Pattern &)> walk) { 103 for (const Pattern &it : frozenPatternList.getNativePatterns()) 104 walk(it); 105 if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) { 106 for (const Pattern &it : bytecode->getPatterns()) 107 walk(it); 108 } 109 } 110 111 LogicalResult PatternApplicator::matchAndRewrite( 112 Operation *op, PatternRewriter &rewriter, 113 function_ref<bool(const Pattern &)> canApply, 114 function_ref<void(const Pattern &)> onFailure, 115 function_ref<LogicalResult(const Pattern &)> onSuccess) { 116 // Before checking native patterns, first match against the bytecode. This 117 // won't automatically perform any rewrites so there is no need to worry about 118 // conflicts. 119 SmallVector<PDLByteCode::MatchResult, 4> pdlMatches; 120 const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode(); 121 if (bytecode) 122 bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState); 123 124 // Check to see if there are patterns matching this specific operation type. 125 MutableArrayRef<const RewritePattern *> opPatterns; 126 auto patternIt = patterns.find(op->getName()); 127 if (patternIt != patterns.end()) 128 opPatterns = patternIt->second; 129 130 // Process the patterns for that match the specific operation type, and any 131 // operation type in an interleaved fashion. 132 auto opIt = opPatterns.begin(), opE = opPatterns.end(); 133 auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); 134 auto pdlIt = pdlMatches.begin(), pdlE = pdlMatches.end(); 135 while (true) { 136 // Find the next pattern with the highest benefit. 137 const Pattern *bestPattern = nullptr; 138 const PDLByteCode::MatchResult *pdlMatch = nullptr; 139 /// Operation specific patterns. 140 if (opIt != opE) 141 bestPattern = *(opIt++); 142 /// Operation agnostic patterns. 143 if (anyIt != anyE && 144 (!bestPattern || bestPattern->getBenefit() < (*anyIt)->getBenefit())) 145 bestPattern = *(anyIt++); 146 /// PDL patterns. 147 if (pdlIt != pdlE && 148 (!bestPattern || bestPattern->getBenefit() < pdlIt->benefit)) { 149 pdlMatch = pdlIt; 150 bestPattern = (pdlIt++)->pattern; 151 } 152 if (!bestPattern) 153 break; 154 155 // Check that the pattern can be applied. 156 if (canApply && !canApply(*bestPattern)) 157 continue; 158 159 // Try to match and rewrite this pattern. The patterns are sorted by 160 // benefit, so if we match we can immediately rewrite. For PDL patterns, the 161 // match has already been performed, we just need to rewrite. 162 rewriter.setInsertionPoint(op); 163 LogicalResult result = success(); 164 if (pdlMatch) { 165 bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState); 166 } else { 167 result = static_cast<const RewritePattern *>(bestPattern) 168 ->matchAndRewrite(op, rewriter); 169 } 170 if (succeeded(result) && (!onSuccess || succeeded(onSuccess(*bestPattern)))) 171 return success(); 172 173 // Perform any necessary cleanups. 174 if (onFailure) 175 onFailure(*bestPattern); 176 } 177 return failure(); 178 } 179