1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 #include "mlir/IR/Operation.h"
12 #include "mlir/IR/Value.h"
13 
14 using namespace mlir;
15 
16 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
17   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
18          "This pattern match benefit is too large to represent");
19 }
20 
/// Return the stored benefit value. Calling this on a benefit that is marked
/// as impossible to match is a programming error and trips the assertion
/// below.
unsigned short PatternBenefit::getBenefit() const {
  assert(!isImpossibleToMatch() && "Pattern doesn't match");
  return representation;
}
25 
26 //===----------------------------------------------------------------------===//
27 // Pattern implementation
28 //===----------------------------------------------------------------------===//
29 
30 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
31                  MLIRContext *context)
32     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
33 
// Out-of-line vtable anchor. The method is intentionally empty; it is defined
// out of line in this file purely to serve as the anchor.
void Pattern::anchor() {}
36 
37 //===----------------------------------------------------------------------===//
38 // RewritePattern and PatternRewriter implementation
39 //===----------------------------------------------------------------------===//
40 
41 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
42   llvm_unreachable("need to implement either matchAndRewrite or one of the "
43                    "rewrite functions!");
44 }
45 
/// Fallback implementation of the match hook. Reaching it is treated as a
/// programming error: per the diagnostic below, a pattern has to implement
/// either match or matchAndRewrite.
LogicalResult RewritePattern::match(Operation *op) const {
  llvm_unreachable("need to implement either match or matchAndRewrite!");
}
49 
50 /// Patterns must specify the root operation name they match against, and can
51 /// also specify the benefit of the pattern matching. They can also specify the
52 /// names of operations that may be generated during a successful rewrite.
53 RewritePattern::RewritePattern(StringRef rootName,
54                                ArrayRef<StringRef> generatedNames,
55                                PatternBenefit benefit, MLIRContext *context)
56     : Pattern(rootName, benefit, context) {
57   generatedOps.reserve(generatedNames.size());
58   std::transform(generatedNames.begin(), generatedNames.end(),
59                  std::back_inserter(generatedOps), [context](StringRef name) {
60                    return OperationName(name, context);
61                  });
62 }
63 
/// Destructor. It performs no work of its own; it is defined here rather than
/// inline so that it can act as the vtable anchor for the class.
PatternRewriter::~PatternRewriter() {
  // Out of line to provide a vtable anchor for the class.
}
67 
/// This method performs the final replacement for a pattern, where the
/// results of the operation are updated to use the specified list of SSA
/// values. 'newValues' must contain exactly one value per result of 'op'.
/// After the uses have been redirected, 'op' is erased.
void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
  // Notify the rewriter subclass that we're about to replace this root.
  notifyRootReplaced(op);

  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  // Redirect every use of the results of 'op' to the replacement values.
  op->replaceAllUsesWith(newValues);

  // Notify the rewriter subclass of the removal before 'op' is actually
  // erased, while the operation is still intact.
  notifyOperationRemoved(op);
  op->erase();
}
82 
/// This method erases an operation that is known to have no uses. The uses of
/// the given operation *must* be known to be dead; this is checked by the
/// assertion below.
void PatternRewriter::eraseOp(Operation *op) {
  assert(op->use_empty() && "expected 'op' to have no uses");
  // Notify the rewriter subclass of the removal before 'op' is actually
  // erased, while the operation is still intact.
  notifyOperationRemoved(op);
  op->erase();
}
90 
91 void PatternRewriter::eraseBlock(Block *block) {
92   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
93     assert(op.use_empty() && "expected 'op' to have no uses");
94     eraseOp(&op);
95   }
96   block->erase();
97 }
98 
99 /// Merge the operations of block 'source' into the end of block 'dest'.
100 /// 'source's predecessors must be empty or only contain 'dest`.
101 /// 'argValues' is used to replace the block arguments of 'source' after
102 /// merging.
103 void PatternRewriter::mergeBlocks(Block *source, Block *dest,
104                                   ValueRange argValues) {
105   assert(llvm::all_of(source->getPredecessors(),
106                       [dest](Block *succ) { return succ == dest; }) &&
107          "expected 'source' to have no predecessors or only 'dest'");
108   assert(argValues.size() == source->getNumArguments() &&
109          "incorrect # of argument replacement values");
110 
111   // Replace all of the successor arguments with the provided values.
112   for (auto it : llvm::zip(source->getArguments(), argValues))
113     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
114 
115   // Splice the operations of the 'source' block into the 'dest' block and erase
116   // it.
117   dest->getOperations().splice(dest->end(), source->getOperations());
118   source->dropAllUses();
119   source->erase();
120 }
121 
/// Split the operations starting at "before" (inclusive) out of the given
/// block into a new block, and return it. This simply forwards to
/// Block::splitBlock.
Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
  return block->splitBlock(before);
}
127 
128 /// 'op' and 'newOp' are known to have the same number of results, replace the
129 /// uses of op with uses of newOp
130 void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
131                                                       Operation *newOp) {
132   assert(op->getNumResults() == newOp->getNumResults() &&
133          "replacement op doesn't match results of original op");
134   if (op->getNumResults() == 1)
135     return replaceOp(op, newOp->getResult(0));
136   return replaceOp(op, newOp->getResults());
137 }
138 
/// Move the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation transferring flow of
/// control to the region and passing it the correct block arguments.
void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
                                         Region::iterator before) {
  // Splice the whole block list of 'region' into 'parent' at 'before'.
  parent.getBlocks().splice(before, region.getBlocks());
}
147 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
148   inlineRegionBefore(region, *before->getParent(), before->getIterator());
149 }
150 
/// Clone the blocks that belong to "region" before the given position in
/// another region "parent". The two regions must be different. The caller is
/// responsible for creating or updating the operation transferring flow of
/// control to the region and passing it the correct block arguments.
/// 'mapping' is handed through unchanged to Region::cloneInto.
void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
                                        Region::iterator before,
                                        BlockAndValueMapping &mapping) {
  region.cloneInto(&parent, before, mapping);
}
160 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
161                                         Region::iterator before) {
162   BlockAndValueMapping mapping;
163   cloneRegionBefore(region, parent, before, mapping);
164 }
165 void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
166   cloneRegionBefore(region, *before->getParent(), before->getIterator());
167 }
168 
169 //===----------------------------------------------------------------------===//
// PatternApplicator implementation
171 //===----------------------------------------------------------------------===//
172 
173 void PatternApplicator::applyCostModel(CostModel model) {
174   // Separate patterns by root kind to simplify lookup later on.
175   patterns.clear();
176   for (const auto &pat : owningPatternList)
177     patterns[pat->getRootKind()].push_back(pat.get());
178 
179   // Sort the patterns using the provided cost model.
180   llvm::SmallDenseMap<RewritePattern *, PatternBenefit> benefits;
181   auto cmp = [&benefits](RewritePattern *lhs, RewritePattern *rhs) {
182     return benefits[lhs] > benefits[rhs];
183   };
184   for (auto &it : patterns) {
185     SmallVectorImpl<RewritePattern *> &list = it.second;
186 
187     // Special case for one pattern in the list, which is the most common case.
188     if (list.size() == 1) {
189       if (model(*list.front()).isImpossibleToMatch())
190         list.clear();
191       continue;
192     }
193 
194     // Collect the dynamic benefits for the current pattern list.
195     benefits.clear();
196     for (RewritePattern *pat : list)
197       benefits.try_emplace(pat, model(*pat));
198 
199     // Sort patterns with highest benefit first, and remove those that are
200     // impossible to match.
201     std::stable_sort(list.begin(), list.end(), cmp);
202     while (!list.empty() && benefits[list.back()].isImpossibleToMatch())
203       list.pop_back();
204   }
205 }
206 
207 void PatternApplicator::walkAllPatterns(
208     function_ref<void(const RewritePattern &)> walk) {
209   for (auto &it : owningPatternList)
210     walk(*it);
211 }
212 
213 /// Try to match the given operation to a pattern and rewrite it.
214 LogicalResult PatternApplicator::matchAndRewrite(
215     Operation *op, PatternRewriter &rewriter,
216     function_ref<bool(const RewritePattern &)> canApply,
217     function_ref<void(const RewritePattern &)> onFailure,
218     function_ref<LogicalResult(const RewritePattern &)> onSuccess) {
219   auto patternIt = patterns.find(op->getName());
220   if (patternIt == patterns.end())
221     return failure();
222 
223   for (auto *pattern : patternIt->second) {
224     // Check that the pattern can be applied.
225     if (canApply && !canApply(*pattern))
226       continue;
227 
228     // Try to match and rewrite this pattern. The patterns are sorted by
229     // benefit, so if we match we can immediately rewrite.
230     rewriter.setInsertionPoint(op);
231     if (succeeded(pattern->matchAndRewrite(op, rewriter))) {
232       if (!onSuccess || succeeded(onSuccess(*pattern)))
233         return success();
234       continue;
235     }
236 
237     if (onFailure)
238       onFailure(*pattern);
239   }
240   return failure();
241 }
242