1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements an applicator that applies pattern rewrites based upon a 10 // user defined cost model. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Rewrite/PatternApplicator.h" 15 #include "llvm/Support/Debug.h" 16 17 using namespace mlir; 18 19 #define DEBUG_TYPE "pattern-match" 20 21 void PatternApplicator::applyCostModel(CostModel model) { 22 // Separate patterns by root kind to simplify lookup later on. 23 patterns.clear(); 24 anyOpPatterns.clear(); 25 for (const auto &pat : frozenPatternList.getPatterns()) { 26 // If the pattern is always impossible to match, just ignore it. 27 if (pat.getBenefit().isImpossibleToMatch()) { 28 LLVM_DEBUG({ 29 llvm::dbgs() 30 << "Ignoring pattern '" << pat.getRootKind() 31 << "' because it is impossible to match (by pattern benefit)\n"; 32 }); 33 continue; 34 } 35 if (Optional<OperationName> opName = pat.getRootKind()) 36 patterns[*opName].push_back(&pat); 37 else 38 anyOpPatterns.push_back(&pat); 39 } 40 41 // Sort the patterns using the provided cost model. 42 llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits; 43 auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) { 44 return benefits[lhs] > benefits[rhs]; 45 }; 46 auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) { 47 // Special case for one pattern in the list, which is the most common case. 48 if (list.size() == 1) { 49 if (model(*list.front()).isImpossibleToMatch()) { 50 LLVM_DEBUG({ 51 llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind() 52 << "' because it is impossible to match or cannot lead " 53 "to legal IR (by cost model)\n"; 54 }); 55 list.clear(); 56 } 57 return; 58 } 59 60 // Collect the dynamic benefits for the current pattern list. 61 benefits.clear(); 62 for (const Pattern *pat : list) 63 benefits.try_emplace(pat, model(*pat)); 64 65 // Sort patterns with highest benefit first, and remove those that are 66 // impossible to match. 67 std::stable_sort(list.begin(), list.end(), cmp); 68 while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) { 69 LLVM_DEBUG({ 70 llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind() 71 << "' because it is impossible to match or cannot lead to " 72 "legal IR (by cost model)\n"; 73 }); 74 list.pop_back(); 75 } 76 }; 77 for (auto &it : patterns) 78 processPatternList(it.second); 79 processPatternList(anyOpPatterns); 80 } 81 82 void PatternApplicator::walkAllPatterns( 83 function_ref<void(const Pattern &)> walk) { 84 for (auto &it : frozenPatternList.getPatterns()) 85 walk(it); 86 } 87 88 LogicalResult PatternApplicator::matchAndRewrite( 89 Operation *op, PatternRewriter &rewriter, 90 function_ref<bool(const Pattern &)> canApply, 91 function_ref<void(const Pattern &)> onFailure, 92 function_ref<LogicalResult(const Pattern &)> onSuccess) { 93 // Check to see if there are patterns matching this specific operation type. 94 MutableArrayRef<const RewritePattern *> opPatterns; 95 auto patternIt = patterns.find(op->getName()); 96 if (patternIt != patterns.end()) 97 opPatterns = patternIt->second; 98 99 // Process the patterns for that match the specific operation type, and any 100 // operation type in an interleaved fashion. 101 // FIXME: It'd be nice to just write an llvm::make_merge_range utility 102 // and pass in a comparison function. That would make this code trivial. 103 auto opIt = opPatterns.begin(), opE = opPatterns.end(); 104 auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end(); 105 while (opIt != opE && anyIt != anyE) { 106 // Try to match the pattern providing the most benefit. 107 const RewritePattern *pattern; 108 if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit()) 109 pattern = *(opIt++); 110 else 111 pattern = *(anyIt++); 112 113 // Otherwise, try to match the generic pattern. 114 if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, 115 onSuccess))) 116 return success(); 117 } 118 // If we break from the loop, then only one of the ranges can still have 119 // elements. Loop over both without checking given that we don't need to 120 // interleave anymore. 121 for (const RewritePattern *pattern : llvm::concat<const RewritePattern *>( 122 llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) { 123 if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure, 124 onSuccess))) 125 return success(); 126 } 127 return failure(); 128 } 129 130 LogicalResult PatternApplicator::matchAndRewrite( 131 Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter, 132 function_ref<bool(const Pattern &)> canApply, 133 function_ref<void(const Pattern &)> onFailure, 134 function_ref<LogicalResult(const Pattern &)> onSuccess) { 135 // Check that the pattern can be applied. 136 if (canApply && !canApply(pattern)) 137 return failure(); 138 139 // Try to match and rewrite this pattern. The patterns are sorted by 140 // benefit, so if we match we can immediately rewrite. 141 rewriter.setInsertionPoint(op); 142 if (succeeded(pattern.matchAndRewrite(op, rewriter))) 143 return success(!onSuccess || succeeded(onSuccess(pattern))); 144 145 if (onFailure) 146 onFailure(pattern); 147 return failure(); 148 } 149