1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/Value.h"
22 using namespace mlir;
23 
24 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
25   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
26          "This pattern match benefit is too large to represent");
27 }
28 
29 unsigned short PatternBenefit::getBenefit() const {
30   assert(representation != ImpossibleToMatchSentinel &&
31          "Pattern doesn't match");
32   return representation;
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // Pattern implementation
37 //===----------------------------------------------------------------------===//
38 
39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40                  MLIRContext *context)
41     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
42 
43 // Out-of-line vtable anchor.
44 void Pattern::anchor() {}
45 
46 //===----------------------------------------------------------------------===//
47 // RewritePattern and PatternRewriter implementation
48 //===----------------------------------------------------------------------===//
49 
50 void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
51                              PatternRewriter &rewriter) const {
52   rewrite(op, rewriter);
53 }
54 
55 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
56   llvm_unreachable("need to implement either matchAndRewrite or one of the "
57                    "rewrite functions!");
58 }
59 
60 PatternMatchResult RewritePattern::match(Operation *op) const {
61   llvm_unreachable("need to implement either match or matchAndRewrite!");
62 }
63 
64 /// Patterns must specify the root operation name they match against, and can
65 /// also specify the benefit of the pattern matching. They can also specify the
66 /// names of operations that may be generated during a successful rewrite.
67 RewritePattern::RewritePattern(StringRef rootName,
68                                ArrayRef<StringRef> generatedNames,
69                                PatternBenefit benefit, MLIRContext *context)
70     : Pattern(rootName, benefit, context) {
71   generatedOps.reserve(generatedNames.size());
72   std::transform(generatedNames.begin(), generatedNames.end(),
73                  std::back_inserter(generatedOps), [context](StringRef name) {
74                    return OperationName(name, context);
75                  });
76 }
77 
78 PatternRewriter::~PatternRewriter() {
79   // Out of line to provide a vtable anchor for the class.
80 }
81 
82 /// This method performs the final replacement for a pattern, where the
83 /// results of the operation are updated to use the specified list of SSA
84 /// values.  In addition to replacing and removing the specified operation,
85 /// clients can specify a list of other nodes that this replacement may make
86 /// (perhaps transitively) dead.  If any of those ops are dead, this will
87 /// remove them as well.
88 void PatternRewriter::replaceOp(Operation *op, ValueRange newValues,
89                                 ValueRange valuesToRemoveIfDead) {
90   // Notify the rewriter subclass that we're about to replace this root.
91   notifyRootReplaced(op);
92 
93   assert(op->getNumResults() == newValues.size() &&
94          "incorrect # of replacement values");
95   op->replaceAllUsesWith(newValues);
96 
97   notifyOperationRemoved(op);
98   op->erase();
99 
100   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
101   // the notifyOperationRemoved hook in the process.
102 }
103 
104 /// This method erases an operation that is known to have no uses. The uses of
105 /// the given operation *must* be known to be dead.
106 void PatternRewriter::eraseOp(Operation *op) {
107   assert(op->use_empty() && "expected 'op' to have no uses");
108   notifyOperationRemoved(op);
109   op->erase();
110 }
111 
112 /// Merge the operations of block 'source' into the end of block 'dest'.
113 /// 'source's predecessors must be empty or only contain 'dest`.
114 /// 'argValues' is used to replace the block arguments of 'source' after
115 /// merging.
116 void PatternRewriter::mergeBlocks(Block *source, Block *dest,
117                                   ValueRange argValues) {
118   assert(llvm::all_of(source->getPredecessors(),
119                       [dest](Block *succ) { return succ == dest; }) &&
120          "expected 'source' to have no predecessors or only 'dest'");
121   assert(argValues.size() == source->getNumArguments() &&
122          "incorrect # of argument replacement values");
123 
124   // Replace all of the successor arguments with the provided values.
125   for (auto it : llvm::zip(source->getArguments(), argValues))
126     std::get<0>(it)->replaceAllUsesWith(std::get<1>(it));
127 
128   // Splice the operations of the 'source' block into the 'dest' block and erase
129   // it.
130   dest->getOperations().splice(dest->end(), source->getOperations());
131   source->dropAllUses();
132   source->erase();
133 }
134 
135 /// Split the operations starting at "before" (inclusive) out of the given
136 /// block into a new block, and return it.
137 Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
138   return block->splitBlock(before);
139 }
140 
141 /// op and newOp are known to have the same number of results, replace the
142 /// uses of op with uses of newOp
143 void PatternRewriter::replaceOpWithResultsOfAnotherOp(
144     Operation *op, Operation *newOp, ValueRange valuesToRemoveIfDead) {
145   assert(op->getNumResults() == newOp->getNumResults() &&
146          "replacement op doesn't match results of original op");
147   if (op->getNumResults() == 1)
148     return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
149   return replaceOp(op, newOp->getResults(), valuesToRemoveIfDead);
150 }
151 
152 /// Move the blocks that belong to "region" before the given position in
153 /// another region.  The two regions must be different.  The caller is in
154 /// charge to update create the operation transferring the control flow to the
155 /// region and pass it the correct block arguments.
156 void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
157                                          Region::iterator before) {
158   parent.getBlocks().splice(before, region.getBlocks());
159 }
160 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
161   inlineRegionBefore(region, *before->getParent(), before->getIterator());
162 }
163 
164 /// Clone the blocks that belong to "region" before the given position in
165 /// another region "parent". The two regions must be different. The caller is
166 /// responsible for creating or updating the operation transferring flow of
167 /// control to the region and passing it the correct block arguments.
168 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
169                                         Region::iterator before,
170                                         BlockAndValueMapping &mapping) {
171   region.cloneInto(&parent, before, mapping);
172 }
173 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
174                                         Region::iterator before) {
175   BlockAndValueMapping mapping;
176   cloneRegionBefore(region, parent, before, mapping);
177 }
178 void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
179   cloneRegionBefore(region, *before->getParent(), before->getIterator());
180 }
181 
182 /// This method is used as the final notification hook for patterns that end
183 /// up modifying the pattern root in place, by changing its operands.  This is
184 /// a minor efficiency win (it avoids creating a new operation and removing
185 /// the old one) but also often allows simpler code in the client.
186 ///
187 /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
188 /// should remove if they are dead at this point.
189 ///
190 void PatternRewriter::updatedRootInPlace(Operation *op,
191                                          ValueRange valuesToRemoveIfDead) {
192   // Notify the rewriter subclass that we're about to replace this root.
193   notifyRootUpdated(op);
194 
195   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
196   // the notifyOperationRemoved hook in the process.
197 }
198 
199 //===----------------------------------------------------------------------===//
200 // PatternMatcher implementation
201 //===----------------------------------------------------------------------===//
202 
203 RewritePatternMatcher::RewritePatternMatcher(
204     const OwningRewritePatternList &patterns) {
205   for (auto &pattern : patterns)
206     this->patterns.push_back(pattern.get());
207 
208   // Sort the patterns by benefit to simplify the matching logic.
209   std::stable_sort(this->patterns.begin(), this->patterns.end(),
210                    [](RewritePattern *l, RewritePattern *r) {
211                      return r->getBenefit() < l->getBenefit();
212                    });
213 }
214 
215 /// Try to match the given operation to a pattern and rewrite it.
216 bool RewritePatternMatcher::matchAndRewrite(Operation *op,
217                                             PatternRewriter &rewriter) {
218   for (auto *pattern : patterns) {
219     // Ignore patterns that are for the wrong root or are impossible to match.
220     if (pattern->getRootKind() != op->getName() ||
221         pattern->getBenefit().isImpossibleToMatch())
222       continue;
223 
224     // Try to match and rewrite this pattern. The patterns are sorted by
225     // benefit, so if we match we can immediately rewrite and return.
226     if (pattern->matchAndRewrite(op, rewriter))
227       return true;
228   }
229   return false;
230 }
231