1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/Value.h"
13 #include "llvm/Support/Debug.h"
14 
15 using namespace mlir;
16 
17 #define DEBUG_TYPE "pattern-match"
18 
/// Construct a benefit from an unsigned value. In assert-enabled builds this
/// checks that `benefit` survives the conversion into `representation`
/// unchanged (i.e. it fits the narrower storage) and that it is not the value
/// reserved for the impossible-to-match sentinel.
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
         "This pattern match benefit is too large to represent");
}
23 
/// Return the numeric benefit. Calling this on a benefit that is impossible to
/// match is a programming error and is caught by an assertion.
unsigned short PatternBenefit::getBenefit() const {
  assert(!isImpossibleToMatch() && "Pattern doesn't match");
  return representation;
}
28 
29 //===----------------------------------------------------------------------===//
30 // Pattern implementation
31 //===----------------------------------------------------------------------===//
32 
/// Construct a pattern rooted on the operation named `rootName`, resolved as an
/// OperationName within `context`.
Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
                 MLIRContext *context)
    : rootKind(OperationName(rootName, context)), benefit(benefit) {}
/// Construct a pattern without a specific root operation. `rootKind` is left
/// default-initialized, which the applicator treats as "any operation type".
Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag)
    : benefit(benefit) {}
38 
// Out-of-line vtable anchor: defining this method out of line, in this file,
// gives the class's vtable a single home translation unit.
void Pattern::anchor() {}
41 
42 //===----------------------------------------------------------------------===//
43 // RewritePattern and PatternRewriter implementation
44 //===----------------------------------------------------------------------===//
45 
/// Default `rewrite` hook. Reaching it means the concrete pattern overrode
/// neither `matchAndRewrite` nor a `rewrite` function, which is a programming
/// error.
void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
  llvm_unreachable("need to implement either matchAndRewrite or one of the "
                   "rewrite functions!");
}
50 
/// Default `match` hook. Reaching it means the concrete pattern overrode
/// neither `match` nor `matchAndRewrite`, which is a programming error.
LogicalResult RewritePattern::match(Operation *op) const {
  llvm_unreachable("need to implement either match or matchAndRewrite!");
}
54 
55 RewritePattern::RewritePattern(StringRef rootName,
56                                ArrayRef<StringRef> generatedNames,
57                                PatternBenefit benefit, MLIRContext *context)
58     : Pattern(rootName, benefit, context) {
59   generatedOps.reserve(generatedNames.size());
60   std::transform(generatedNames.begin(), generatedNames.end(),
61                  std::back_inserter(generatedOps), [context](StringRef name) {
62                    return OperationName(name, context);
63                  });
64 }
65 RewritePattern::RewritePattern(ArrayRef<StringRef> generatedNames,
66                                PatternBenefit benefit, MLIRContext *context,
67                                MatchAnyOpTypeTag tag)
68     : Pattern(benefit, tag) {
69   generatedOps.reserve(generatedNames.size());
70   std::transform(generatedNames.begin(), generatedNames.end(),
71                  std::back_inserter(generatedOps), [context](StringRef name) {
72                    return OperationName(name, context);
73                  });
74 }
75 
76 PatternRewriter::~PatternRewriter() {
77   // Out of line to provide a vtable anchor for the class.
78 }
79 
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values.
///
/// Sequence: notifyRootReplaced(op) -> redirect every use of op's results to
/// `newValues` -> notifyOperationRemoved(op) -> erase op. `newValues` must
/// supply exactly one value per result of `op` (asserted).
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
  // Notify the rewriter subclass that we're about to replace this root.
  notifyRootReplaced(op);

  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  op->replaceAllUsesWith(newValues);

  // Notify before erasing so the hook still sees a valid operation.
  notifyOperationRemoved(op);
  op->erase();
}
94 
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead (asserted). The
/// notifyOperationRemoved hook is invoked before the operation is destroyed.
void PatternRewriter::eraseOp(Operation *op) {
  assert(op->use_empty() && "expected 'op' to have no uses");
  notifyOperationRemoved(op);
  op->erase();
}
102 
/// Erase every operation in `block`, last to first, through eraseOp, and then
/// erase the block itself. Every operation must be use-free at the time it is
/// visited (asserted). No check is made here that the block itself has no
/// remaining uses.
void PatternRewriter::eraseBlock(Block *block) {
  // make_early_inc_range advances the iterator before the body runs. The loop
  // therefore stays valid when eraseOp unlinks the current op, and it also
  // terminates if an overriding eraseOp only records the op for later removal.
  // NOTE(review): whether eraseOp is virtual/overridden is declared in the
  // header; confirm there.
  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
    assert(op.use_empty() && "expected 'op' to have no uses");
    eraseOp(&op);
  }
  block->erase();
}
110 
111 /// Merge the operations of block 'source' into the end of block 'dest'.
112 /// 'source's predecessors must be empty or only contain 'dest`.
113 /// 'argValues' is used to replace the block arguments of 'source' after
114 /// merging.
115 void PatternRewriter::mergeBlocks(Block *source, Block *dest,
116                                   ValueRange argValues) {
117   assert(llvm::all_of(source->getPredecessors(),
118                       [dest](Block *succ) { return succ == dest; }) &&
119          "expected 'source' to have no predecessors or only 'dest'");
120   assert(argValues.size() == source->getNumArguments() &&
121          "incorrect # of argument replacement values");
122 
123   // Replace all of the successor arguments with the provided values.
124   for (auto it : llvm::zip(source->getArguments(), argValues))
125     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
126 
127   // Splice the operations of the 'source' block into the 'dest' block and erase
128   // it.
129   dest->getOperations().splice(dest->end(), source->getOperations());
130   source->dropAllUses();
131   source->erase();
132 }
133 
134 // Merge the operations of block 'source' before the operation 'op'. Source
135 // block should not have existing predecessors or successors.
136 void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
137                                        ValueRange argValues) {
138   assert(source->hasNoPredecessors() &&
139          "expected 'source' to have no predecessors");
140   assert(source->hasNoSuccessors() &&
141          "expected 'source' to have no successors");
142 
143   // Split the block containing 'op' into two, one containg all operations
144   // before 'op' (prologue) and another (epilogue) containing 'op' and all
145   // operations after it.
146   Block *prologue = op->getBlock();
147   Block *epilogue = splitBlock(prologue, op->getIterator());
148 
149   // Merge the source block at the end of the prologue.
150   mergeBlocks(source, prologue, argValues);
151 
152   // Merge the epilogue at the end the prologue.
153   mergeBlocks(epilogue, prologue);
154 }
155 
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it. This forwards directly to
/// Block::splitBlock.
Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
  return block->splitBlock(before);
}
161 
162 /// 'op' and 'newOp' are known to have the same number of results, replace the
163 /// uses of op with uses of newOp
164 void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
165                                                       Operation *newOp) {
166   assert(op->getNumResults() == newOp->getNumResults() &&
167          "replacement op doesn't match results of original op");
168   if (op->getNumResults() == 1)
169     return replaceOp(op, newOp->getResult(0));
170   return replaceOp(op, newOp->getResults());
171 }
172 
/// Move the blocks that belong to "region" before the given position in
/// another region. The two regions must be different. The caller is
/// responsible for creating or updating the operation that transfers control
/// flow to the region, and for passing it the correct block arguments.
/// After this call "region" no longer owns any of the moved blocks.
void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
                                         Region::iterator before) {
  parent.getBlocks().splice(before, region.getBlocks());
}
/// Convenience overload: move the blocks of "region" in front of block
/// `before`, inside the region that contains `before`.
void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
  inlineRegionBefore(region, *before->getParent(), before->getIterator());
}
184 
/// Clone the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation transferring flow of
/// control to the region and passing it the correct block arguments.
/// `mapping` is passed through to Region::cloneInto. Presumably cloneInto both
/// consults it and records original-to-clone entries in it; confirm in
/// Region's documentation.
void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
                                        Region::iterator before,
                                        BlockAndValueMapping &mapping) {
  region.cloneInto(&parent, before, mapping);
}
/// Overload that uses a fresh, local mapping which is discarded afterwards.
void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
                                        Region::iterator before) {
  BlockAndValueMapping mapping;
  cloneRegionBefore(region, parent, before, mapping);
}
/// Overload that clones in front of block `before`, inside the region that
/// contains `before`.
void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
  cloneRegionBefore(region, *before->getParent(), before->getIterator());
}
202 
203 //===----------------------------------------------------------------------===//
// PatternApplicator implementation
205 //===----------------------------------------------------------------------===//
206 
207 void PatternApplicator::applyCostModel(CostModel model) {
208   // Separate patterns by root kind to simplify lookup later on.
209   patterns.clear();
210   anyOpPatterns.clear();
211   for (const auto &pat : owningPatternList) {
212     // If the pattern is always impossible to match, just ignore it.
213     if (pat->getBenefit().isImpossibleToMatch()) {
214       LLVM_DEBUG({
215         llvm::dbgs()
216             << "Ignoring pattern '" << pat->getRootKind()
217             << "' because it is impossible to match (by pattern benefit)\n";
218       });
219       continue;
220     }
221     if (Optional<OperationName> opName = pat->getRootKind())
222       patterns[*opName].push_back(pat.get());
223     else
224       anyOpPatterns.push_back(pat.get());
225   }
226 
227   // Sort the patterns using the provided cost model.
228   llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
229   auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
230     return benefits[lhs] > benefits[rhs];
231   };
232   auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
233     // Special case for one pattern in the list, which is the most common case.
234     if (list.size() == 1) {
235       if (model(*list.front()).isImpossibleToMatch()) {
236         LLVM_DEBUG({
237           llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
238                        << "' because it is impossible to match or cannot lead "
239                           "to legal IR (by cost model)\n";
240         });
241         list.clear();
242       }
243       return;
244     }
245 
246     // Collect the dynamic benefits for the current pattern list.
247     benefits.clear();
248     for (RewritePattern *pat : list)
249       benefits.try_emplace(pat, model(*pat));
250 
251     // Sort patterns with highest benefit first, and remove those that are
252     // impossible to match.
253     std::stable_sort(list.begin(), list.end(), cmp);
254     while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
255       LLVM_DEBUG({
256         llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
257                      << "' because it is impossible to match or cannot lead to "
258                         "legal IR (by cost model)\n";
259       });
260       list.pop_back();
261     }
262   };
263   for (auto &it : patterns)
264     processPatternList(it.second);
265   processPatternList(anyOpPatterns);
266 }
267 
268 void PatternApplicator::walkAllPatterns(
269     function_ref<void(const RewritePattern &)> walk) {
270   for (auto &it : owningPatternList)
271     walk(*it);
272 }
273 
274 LogicalResult PatternApplicator::matchAndRewrite(
275     Operation *op, PatternRewriter &rewriter,
276     function_ref<bool(const RewritePattern &)> canApply,
277     function_ref<void(const RewritePattern &)> onFailure,
278     function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
279   // Check to see if there are patterns matching this specific operation type.
280   MutableArrayRef<RewritePattern *> opPatterns;
281   auto patternIt = patterns.find(op->getName());
282   if (patternIt != patterns.end())
283     opPatterns = patternIt->second;
284 
285   // Process the patterns for that match the specific operation type, and any
286   // operation type in an interleaved fashion.
287   // FIXME: It'd be nice to just write an llvm::make_merge_range utility
288   // and pass in a comparison function. That would make this code trivial.
289   auto opIt = opPatterns.begin(), opE = opPatterns.end();
290   auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
291   while (opIt != opE && anyIt != anyE) {
292     // Try to match the pattern providing the most benefit.
293     RewritePattern *pattern;
294     if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
295       pattern = *(opIt++);
296     else
297       pattern = *(anyIt++);
298 
299     // Otherwise, try to match the generic pattern.
300     if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
301                                   onSuccess)))
302       return success();
303   }
304   // If we break from the loop, then only one of the ranges can still have
305   // elements. Loop over both without checking given that we don't need to
306   // interleave anymore.
307   for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
308            llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
309     if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
310                                   onSuccess)))
311       return success();
312   }
313   return failure();
314 }
315 
316 LogicalResult PatternApplicator::matchAndRewrite(
317     Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
318     function_ref<bool(const RewritePattern &)> canApply,
319     function_ref<void(const RewritePattern &)> onFailure,
320     function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
321   // Check that the pattern can be applied.
322   if (canApply && !canApply(pattern))
323     return failure();
324 
325   // Try to match and rewrite this pattern. The patterns are sorted by
326   // benefit, so if we match we can immediately rewrite.
327   rewriter.setInsertionPoint(op);
328   if (succeeded(pattern.matchAndRewrite(op, rewriter)))
329     return success(!onSuccess || succeeded(onSuccess(pattern)));
330 
331   if (onFailure)
332     onFailure(pattern);
333   return failure();
334 }
335