1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/Value.h"
13 #include "llvm/Support/Debug.h"
14 
15 using namespace mlir;
16 
17 #define DEBUG_TYPE "pattern-match"
18 
19 //===----------------------------------------------------------------------===//
20 // PatternBenefit
21 //===----------------------------------------------------------------------===//
22 
23 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
24   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
25          "This pattern match benefit is too large to represent");
26 }
27 
28 unsigned short PatternBenefit::getBenefit() const {
29   assert(!isImpossibleToMatch() && "Pattern doesn't match");
30   return representation;
31 }
32 
33 //===----------------------------------------------------------------------===//
34 // Pattern
35 //===----------------------------------------------------------------------===//
36 
37 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
38                  MLIRContext *context)
39     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
40 Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
41     : benefit(benefit) {}
42 Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
43                  PatternBenefit benefit, MLIRContext *context)
44     : Pattern(rootName, benefit, context) {
45   generatedOps.reserve(generatedNames.size());
46   std::transform(generatedNames.begin(), generatedNames.end(),
47                  std::back_inserter(generatedOps), [context](StringRef name) {
48                    return OperationName(name, context);
49                  });
50 }
51 Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
52                  MLIRContext *context, MatchAnyOpTypeTag tag)
53     : Pattern(benefit, tag) {
54   generatedOps.reserve(generatedNames.size());
55   std::transform(generatedNames.begin(), generatedNames.end(),
56                  std::back_inserter(generatedOps), [context](StringRef name) {
57                    return OperationName(name, context);
58                  });
59 }
60 
61 //===----------------------------------------------------------------------===//
62 // RewritePattern
63 //===----------------------------------------------------------------------===//
64 
65 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
66   llvm_unreachable("need to implement either matchAndRewrite or one of the "
67                    "rewrite functions!");
68 }
69 
70 LogicalResult RewritePattern::match(Operation *op) const {
71   llvm_unreachable("need to implement either match or matchAndRewrite!");
72 }
73 
74 /// Out-of-line vtable anchor.
75 void RewritePattern::anchor() {}
76 
77 //===----------------------------------------------------------------------===//
78 // PatternRewriter
79 //===----------------------------------------------------------------------===//
80 
81 PatternRewriter::~PatternRewriter() {
82   // Out of line to provide a vtable anchor for the class.
83 }
84 
85 /// This method performs the final replacement for a pattern, where the
86 /// results of the operation are updated to use the specified list of SSA
87 /// values.
88 void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
89   // Notify the rewriter subclass that we're about to replace this root.
90   notifyRootReplaced(op);
91 
92   assert(op->getNumResults() == newValues.size() &&
93          "incorrect # of replacement values");
94   op->replaceAllUsesWith(newValues);
95 
96   notifyOperationRemoved(op);
97   op->erase();
98 }
99 
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead.
///
/// The subclass hook `notifyOperationRemoved` is invoked before the operation
/// is erased.
void PatternRewriter::eraseOp(Operation *op) {
  assert(op->use_empty() && "expected 'op' to have no uses");
  // Notify first, so the hook sees the operation while it is still intact.
  notifyOperationRemoved(op);
  op->erase();
}
107 
/// Erase every operation in `block` through `eraseOp`, then erase the block
/// itself. Each operation must be unused by the time it is reached; that is,
/// any remaining users must be later operations in this same block.
void PatternRewriter::eraseBlock(Block *block) {
  // Walk back to front so that users are erased before the earlier operations
  // whose results they use, letting each `use_empty()` check hold.
  // `make_early_inc_range` advances the iterator before the body runs, so
  // erasing the current operation does not break the traversal.
  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
    assert(op.use_empty() && "expected 'op' to have no uses");
    // Routed through `eraseOp` so the subclass is notified for each op.
    eraseOp(&op);
  }
  // NOTE(review): no rewriter notification hook is invoked for the block
  // itself here.
  block->erase();
}
115 
116 /// Merge the operations of block 'source' into the end of block 'dest'.
117 /// 'source's predecessors must be empty or only contain 'dest`.
118 /// 'argValues' is used to replace the block arguments of 'source' after
119 /// merging.
120 void PatternRewriter::mergeBlocks(Block *source, Block *dest,
121                                   ValueRange argValues) {
122   assert(llvm::all_of(source->getPredecessors(),
123                       [dest](Block *succ) { return succ == dest; }) &&
124          "expected 'source' to have no predecessors or only 'dest'");
125   assert(argValues.size() == source->getNumArguments() &&
126          "incorrect # of argument replacement values");
127 
128   // Replace all of the successor arguments with the provided values.
129   for (auto it : llvm::zip(source->getArguments(), argValues))
130     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
131 
132   // Splice the operations of the 'source' block into the 'dest' block and erase
133   // it.
134   dest->getOperations().splice(dest->end(), source->getOperations());
135   source->dropAllUses();
136   source->erase();
137 }
138 
139 // Merge the operations of block 'source' before the operation 'op'. Source
140 // block should not have existing predecessors or successors.
141 void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
142                                        ValueRange argValues) {
143   assert(source->hasNoPredecessors() &&
144          "expected 'source' to have no predecessors");
145   assert(source->hasNoSuccessors() &&
146          "expected 'source' to have no successors");
147 
148   // Split the block containing 'op' into two, one containg all operations
149   // before 'op' (prologue) and another (epilogue) containing 'op' and all
150   // operations after it.
151   Block *prologue = op->getBlock();
152   Block *epilogue = splitBlock(prologue, op->getIterator());
153 
154   // Merge the source block at the end of the prologue.
155   mergeBlocks(source, prologue, argValues);
156 
157   // Merge the epilogue at the end the prologue.
158   mergeBlocks(epilogue, prologue);
159 }
160 
161 /// Split the operations starting at "before" (inclusive) out of the given
162 /// block into a new block, and return it.
163 Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
164   return block->splitBlock(before);
165 }
166 
167 /// 'op' and 'newOp' are known to have the same number of results, replace the
168 /// uses of op with uses of newOp
169 void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
170                                                       Operation *newOp) {
171   assert(op->getNumResults() == newOp->getNumResults() &&
172          "replacement op doesn't match results of original op");
173   if (op->getNumResults() == 1)
174     return replaceOp(op, newOp->getResult(0));
175   return replaceOp(op, newOp->getResults());
176 }
177 
178 /// Move the blocks that belong to "region" before the given position in
179 /// another region.  The two regions must be different.  The caller is in
180 /// charge to update create the operation transferring the control flow to the
181 /// region and pass it the correct block arguments.
182 void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
183                                          Region::iterator before) {
184   parent.getBlocks().splice(before, region.getBlocks());
185 }
186 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
187   inlineRegionBefore(region, *before->getParent(), before->getIterator());
188 }
189 
190 /// Clone the blocks that belong to "region" before the given position in
191 /// another region "parent". The two regions must be different. The caller is
192 /// responsible for creating or updating the operation transferring flow of
193 /// control to the region and passing it the correct block arguments.
194 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
195                                         Region::iterator before,
196                                         BlockAndValueMapping &mapping) {
197   region.cloneInto(&parent, before, mapping);
198 }
199 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
200                                         Region::iterator before) {
201   BlockAndValueMapping mapping;
202   cloneRegionBefore(region, parent, before, mapping);
203 }
204 void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
205   cloneRegionBefore(region, *before->getParent(), before->getIterator());
206 }
207 
208 //===----------------------------------------------------------------------===//
209 // PatternApplicator
210 //===----------------------------------------------------------------------===//
211 
/// Rebuild the lookup lists from `owningPatternList` and order them by cost.
///
/// Patterns with a root kind go into the per-root-kind lists in `patterns`;
/// the rest go into `anyOpPatterns`. Each list is then ordered by the benefit
/// that `model` assigns, highest first.
///
/// Two kinds of pattern are dropped from these lists, though they stay in
/// `owningPatternList`: those whose static benefit is impossible to match, and
/// those whose benefit under `model` is.
void PatternApplicator::applyCostModel(CostModel model) {
  // Separate patterns by root kind to simplify lookup later on.
  patterns.clear();
  anyOpPatterns.clear();
  for (const auto &pat : owningPatternList) {
    // If the pattern is always impossible to match, just ignore it.
    if (pat->getBenefit().isImpossibleToMatch()) {
      LLVM_DEBUG({
        llvm::dbgs()
            << "Ignoring pattern '" << pat->getRootKind()
            << "' because it is impossible to match (by pattern benefit)\n";
      });
      continue;
    }
    // Patterns with a root kind are bucketed by it; patterns without one can
    // apply to any operation.
    if (Optional<OperationName> opName = pat->getRootKind())
      patterns[*opName].push_back(pat.get());
    else
      anyOpPatterns.push_back(pat.get());
  }

  // Sort the patterns using the provided cost model.
  // `benefits` caches `model`'s result for each pattern, so the comparator
  // does not re-run the model. It is refilled for every list processed below.
  llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
  auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
    return benefits[lhs] > benefits[rhs];
  };
  auto processPatternList = [&](SmallVectorImpl<RewritePattern *> &list) {
    // Special case for one pattern in the list, which is the most common case.
    // No ordering is needed; only check whether the model rules it out.
    if (list.size() == 1) {
      if (model(*list.front()).isImpossibleToMatch()) {
        LLVM_DEBUG({
          llvm::dbgs() << "Ignoring pattern '" << list.front()->getRootKind()
                       << "' because it is impossible to match or cannot lead "
                          "to legal IR (by cost model)\n";
        });
        list.clear();
      }
      return;
    }

    // Collect the dynamic benefits for the current pattern list.
    benefits.clear();
    for (RewritePattern *pat : list)
      benefits.try_emplace(pat, model(*pat));

    // Sort patterns with highest benefit first, and remove those that are
    // impossible to match.
    // The stable sort keeps the original relative order of patterns with
    // equal benefit. Popping only from the back assumes that "impossible to
    // match" compares below every other benefit.
    // NOTE(review): confirm that ordering in PatternBenefit's comparison
    // operators.
    std::stable_sort(list.begin(), list.end(), cmp);
    while (!list.empty() && benefits[list.back()].isImpossibleToMatch()) {
      LLVM_DEBUG({
        llvm::dbgs() << "Ignoring pattern '" << list.back()->getRootKind()
                     << "' because it is impossible to match or cannot lead to "
                        "legal IR (by cost model)\n";
      });
      list.pop_back();
    }
  };
  // Order each root-kind bucket, then the match-any list.
  for (auto &it : patterns)
    processPatternList(it.second);
  processPatternList(anyOpPatterns);
}
272 
273 void PatternApplicator::walkAllPatterns(
274     function_ref<void(const Pattern &)> walk) {
275   for (auto &it : owningPatternList)
276     walk(*it);
277 }
278 
/// Try the patterns that apply to `op` in order of decreasing benefit: those
/// rooted on `op`'s operation name, and the match-any patterns. Stop at the
/// first pattern for which the single-pattern `matchAndRewrite` succeeds.
/// Returns success if some pattern applied, and failure otherwise.
LogicalResult PatternApplicator::matchAndRewrite(
    Operation *op, PatternRewriter &rewriter,
    function_ref<bool(const Pattern &)> canApply,
    function_ref<void(const Pattern &)> onFailure,
    function_ref<LogicalResult(const Pattern &)> onSuccess) {
  // Check to see if there are patterns matching this specific operation type.
  MutableArrayRef<RewritePattern *> opPatterns;
  auto patternIt = patterns.find(op->getName());
  if (patternIt != patterns.end())
    opPatterns = patternIt->second;

  // Process the patterns that match the specific operation type and those that
  // match any operation type in an interleaved fashion, i.e. merge the two
  // lists by benefit. On equal benefit, the operation-specific pattern is
  // tried first.
  // NOTE(review): this merge compares static `getBenefit()` values, while
  // `applyCostModel` sorts each list by the cost model's benefits. If the model
  // reorders patterns, the interleaving is only approximately by benefit.
  // FIXME: It'd be nice to just write an llvm::make_merge_range utility
  // and pass in a comparison function. That would make this code trivial.
  auto opIt = opPatterns.begin(), opE = opPatterns.end();
  auto anyIt = anyOpPatterns.begin(), anyE = anyOpPatterns.end();
  while (opIt != opE && anyIt != anyE) {
    // Try to match the pattern providing the most benefit.
    RewritePattern *pattern;
    if ((*opIt)->getBenefit() >= (*anyIt)->getBenefit())
      pattern = *(opIt++);
    else
      pattern = *(anyIt++);

    // Try to apply the selected pattern; stop at the first success.
    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
                                  onSuccess)))
      return success();
  }
  // Once the loop exits, at most one of the ranges still has elements. Walk
  // both without comparing, since no further interleaving is needed.
  for (RewritePattern *pattern : llvm::concat<RewritePattern *>(
           llvm::make_range(opIt, opE), llvm::make_range(anyIt, anyE))) {
    if (succeeded(matchAndRewrite(op, *pattern, rewriter, canApply, onFailure,
                                  onSuccess)))
      return success();
  }
  return failure();
}
320 
321 LogicalResult PatternApplicator::matchAndRewrite(
322     Operation *op, const RewritePattern &pattern, PatternRewriter &rewriter,
323     function_ref<bool(const Pattern &)> canApply,
324     function_ref<void(const Pattern &)> onFailure,
325     function_ref<LogicalResult(const Pattern &)> onSuccess) {
326   // Check that the pattern can be applied.
327   if (canApply && !canApply(pattern))
328     return failure();
329 
330   // Try to match and rewrite this pattern. The patterns are sorted by
331   // benefit, so if we match we can immediately rewrite.
332   rewriter.setInsertionPoint(op);
333   if (succeeded(pattern.matchAndRewrite(op, rewriter)))
334     return success(!onSuccess || succeeded(onSuccess(pattern)));
335 
336   if (onFailure)
337     onFailure(pattern);
338   return failure();
339 }
340