1 //===- PatternApplicator.cpp - Pattern Application Engine -------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements an applicator that applies pattern rewrites based upon a
// user-defined cost model.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Rewrite/PatternApplicator.h"
15 #include "ByteCode.h"
16 #include "llvm/Support/Debug.h"
17 
18 #define DEBUG_TYPE "pattern-match"
19 
20 using namespace mlir;
21 using namespace mlir::detail;
22 
23 PatternApplicator::PatternApplicator(
24     const FrozenRewritePatternSet &frozenPatternList)
25     : frozenPatternList(frozenPatternList) {
26   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
27     mutableByteCodeState = std::make_unique<PDLByteCodeMutableState>();
28     bytecode->initializeMutableState(*mutableByteCodeState);
29   }
30 }
31 PatternApplicator::~PatternApplicator() {}
32 
33 /// Log a message for a pattern that is impossible to match.
34 static void logImpossibleToMatch(const Pattern &pattern) {
35   LLVM_DEBUG({
36     llvm::dbgs() << "Ignoring pattern '" << pattern.getRootKind()
37                  << "' because it is impossible to match or cannot lead "
38                     "to legal IR (by cost model)\n";
39   });
40 }
41 
42 void PatternApplicator::applyCostModel(CostModel model) {
43   // Apply the cost model to the bytecode patterns first, and then the native
44   // patterns.
45   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
46     for (auto it : llvm::enumerate(bytecode->getPatterns()))
47       mutableByteCodeState->updatePatternBenefit(it.index(), model(it.value()));
48   }
49 
50   // Copy over the patterns so that we can sort by benefit based on the cost
51   // model. Patterns that are already impossible to match are ignored.
52   patterns.clear();
53   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns()) {
54     for (const RewritePattern *pattern : it.second) {
55       if (pattern->getBenefit().isImpossibleToMatch())
56         logImpossibleToMatch(*pattern);
57       else
58         patterns[it.first].push_back(pattern);
59     }
60   }
61   anyOpPatterns.clear();
62   for (const RewritePattern &pattern :
63        frozenPatternList.getMatchAnyOpNativePatterns()) {
64     if (pattern.getBenefit().isImpossibleToMatch())
65       logImpossibleToMatch(pattern);
66     else
67       anyOpPatterns.push_back(&pattern);
68   }
69 
70   // Sort the patterns using the provided cost model.
71   llvm::SmallDenseMap<const Pattern *, PatternBenefit> benefits;
72   auto cmp = [&benefits](const Pattern *lhs, const Pattern *rhs) {
73     return benefits[lhs] > benefits[rhs];
74   };
75   auto processPatternList = [&](SmallVectorImpl<const RewritePattern *> &list) {
76     // Special case for one pattern in the list, which is the most common case.
77     if (list.size() == 1) {
78       if (model(*list.front()).isImpossibleToMatch()) {
79         logImpossibleToMatch(*list.front());
80         list.clear();
81       }
82       return;
83     }
84 
85     // Collect the dynamic benefits for the current pattern list.
86     benefits.clear();
87     for (const Pattern *pat : list)
88       benefits.try_emplace(pat, model(*pat));
89 
90     // Sort patterns with highest benefit first, and remove those that are
91     // impossible to match.
92     std::stable_sort(list.begin(), list.end(), cmp);
93     while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
94       logImpossibleToMatch(*list.pop_back_val());
95   };
96   for (auto &it : patterns)
97     processPatternList(it.second);
98   processPatternList(anyOpPatterns);
99 }
100 
101 void PatternApplicator::walkAllPatterns(
102     function_ref<void(const Pattern &)> walk) {
103   for (const auto &it : frozenPatternList.getOpSpecificNativePatterns())
104     for (const auto &pattern : it.second)
105       walk(*pattern);
106   for (const Pattern &it : frozenPatternList.getMatchAnyOpNativePatterns())
107     walk(it);
108   if (const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode()) {
109     for (const Pattern &it : bytecode->getPatterns())
110       walk(it);
111   }
112 }
113 
/// Attempt to rewrite `op` with the highest-benefit applicable pattern.
///
/// Candidates come from three sources:
/// * PDL bytecode matches for `op`, collected up front via `PDLByteCode::match`;
/// * native patterns registered for `op`'s operation name;
/// * native patterns that may match any operation.
///
/// The sources are merged by benefit, and candidates are tried one at a time
/// until one succeeds:
/// * `canApply` (optional) filters candidates. A rejected candidate is skipped
///   without invoking `onFailure`.
/// * `onSuccess` (optional) is invoked after a successful rewrite. If it
///   returns failure, the attempt is treated as failed. Any IR changes the
///   pattern already made are not undone here.
/// * `onFailure` (optional) is invoked after each failed attempt, before the
///   next candidate is tried.
///
/// Returns success if some candidate succeeded, and failure otherwise. In both
/// cases, `cleanupAfterMatchAndRewrite` is called on the mutable bytecode state
/// (when present) before returning.
///
/// NOTE(review): candidates from different sources are compared using the
/// native patterns' static `getBenefit()`. However, the native lists are
/// ordered by the cost-model benefit (see applyCostModel), and PDL matches
/// carry their own `benefit` (presumably the one set via updatePatternBenefit).
/// With a cost model that changes benefits, the cross-source interleaving may
/// not follow the model. Confirm this is intended.
LogicalResult PatternApplicator::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter,
    function_ref<bool(const Pattern &)> canApply,
    function_ref<void(const Pattern &)> onFailure,
    function_ref<LogicalResult(const Pattern &)> onSuccess) {
  // Before checking native patterns, first match against the bytecode. This
  // won't automatically perform any rewrites so there is no need to worry about
  // conflicts.
  SmallVector<PDLByteCode::MatchResult, 4> pdlMatches;
  const PDLByteCode *bytecode = frozenPatternList.getPDLByteCode();
  if (bytecode)
    bytecode->match(op, rewriter, pdlMatches, *mutableByteCodeState);

  // Check to see if there are patterns matching this specific operation type.
  MutableArrayRef<const RewritePattern *> opPatterns;
  auto patternIt = patterns.find(op->getName());
  if (patternIt != patterns.end())
    opPatterns = patternIt->second;

  // Process the patterns that match the specific operation type, the patterns
  // that match any operation type, and the PDL matches in an interleaved
  // fashion. Each source is consumed front to back through its own cursor. The
  // merge assumes each source is ordered by descending benefit: the native
  // lists are sorted in applyCostModel, and the PDL match order presumably
  // comes from PDLByteCode::match (confirm).
  unsigned opIt = 0, opE = opPatterns.size();
  unsigned anyIt = 0, anyE = anyOpPatterns.size();
  unsigned pdlIt = 0, pdlE = pdlMatches.size();
  LogicalResult result = failure();
  do {
    // Find the next pattern with the highest benefit.
    const Pattern *bestPattern = nullptr;
    unsigned *bestPatternIt = &opIt;
    // Non-null only when the chosen candidate is a PDL match.
    const PDLByteCode::MatchResult *pdlMatch = nullptr;

    /// Operation specific patterns.
    if (opIt < opE)
      bestPattern = opPatterns[opIt];
    /// Operation agnostic patterns. The comparison is strict, so on equal
    /// benefit the operation specific pattern is preferred.
    if (anyIt < anyE &&
        (!bestPattern ||
         bestPattern->getBenefit() < anyOpPatterns[anyIt]->getBenefit())) {
      bestPatternIt = &anyIt;
      bestPattern = anyOpPatterns[anyIt];
    }
    /// PDL patterns. The comparison is strict, so on equal benefit the native
    /// pattern chosen above is preferred over the PDL match.
    if (pdlIt < pdlE && (!bestPattern || bestPattern->getBenefit() <
                                             pdlMatches[pdlIt].benefit)) {
      bestPatternIt = &pdlIt;
      pdlMatch = &pdlMatches[pdlIt];
      bestPattern = pdlMatch->pattern;
    }
    // All three sources are exhausted.
    if (!bestPattern)
      break;

    // Advance the chosen source's cursor before attempting the pattern, so
    // that this candidate isn't attempted again if it is skipped or fails.
    ++(*bestPatternIt);

    // Check that the pattern can be applied. In this do-while(true) loop,
    // `continue` moves straight on to selecting the next candidate.
    if (canApply && !canApply(*bestPattern))
      continue;

    // Try to match and rewrite this pattern. The patterns are sorted by
    // benefit, so if we match we can immediately rewrite. For PDL patterns, the
    // match has already been performed, we just need to rewrite.
    rewriter.setInsertionPoint(op);
    if (pdlMatch) {
      bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
      // The PDL rewrite is treated as successful unless `onSuccess` rejects it.
      result = success(!onSuccess || succeeded(onSuccess(*bestPattern)));

    } else {
      const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
      result = pattern->matchAndRewrite(op, rewriter);
      // A successful rewrite can still be turned into a failure by `onSuccess`.
      if (succeeded(result) && onSuccess && failed(onSuccess(*pattern)))
        result = failure();
    }
    if (succeeded(result))
      break;

    // Perform any necessary cleanups.
    if (onFailure)
      onFailure(*bestPattern);
  } while (true);

  if (mutableByteCodeState)
    mutableByteCodeState->cleanupAfterMatchAndRewrite();
  return result;
}
199