1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/Value.h"
22 using namespace mlir;
23 
24 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
25   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
26          "This pattern match benefit is too large to represent");
27 }
28 
29 unsigned short PatternBenefit::getBenefit() const {
30   assert(representation != ImpossibleToMatchSentinel &&
31          "Pattern doesn't match");
32   return representation;
33 }
34 
35 //===----------------------------------------------------------------------===//
36 // Pattern implementation
37 //===----------------------------------------------------------------------===//
38 
39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40                  MLIRContext *context)
41     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
42 
43 // Out-of-line vtable anchor.
44 void Pattern::anchor() {}
45 
46 //===----------------------------------------------------------------------===//
47 // RewritePattern and PatternRewriter implementation
48 //===----------------------------------------------------------------------===//
49 
50 void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
51                              PatternRewriter &rewriter) const {
52   rewrite(op, rewriter);
53 }
54 
55 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
56   llvm_unreachable("need to implement either matchAndRewrite or one of the "
57                    "rewrite functions!");
58 }
59 
60 PatternMatchResult RewritePattern::match(Operation *op) const {
61   llvm_unreachable("need to implement either match or matchAndRewrite!");
62 }
63 
64 /// Patterns must specify the root operation name they match against, and can
65 /// also specify the benefit of the pattern matching. They can also specify the
66 /// names of operations that may be generated during a successful rewrite.
67 RewritePattern::RewritePattern(StringRef rootName,
68                                ArrayRef<StringRef> generatedNames,
69                                PatternBenefit benefit, MLIRContext *context)
70     : Pattern(rootName, benefit, context) {
71   generatedOps.reserve(generatedNames.size());
72   std::transform(generatedNames.begin(), generatedNames.end(),
73                  std::back_inserter(generatedOps), [context](StringRef name) {
74                    return OperationName(name, context);
75                  });
76 }
77 
78 PatternRewriter::~PatternRewriter() {
79   // Out of line to provide a vtable anchor for the class.
80 }
81 
82 /// This method performs the final replacement for a pattern, where the
83 /// results of the operation are updated to use the specified list of SSA
84 /// values.  In addition to replacing and removing the specified operation,
85 /// clients can specify a list of other nodes that this replacement may make
86 /// (perhaps transitively) dead.  If any of those ops are dead, this will
87 /// remove them as well.
88 void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
89                                 ArrayRef<Value *> valuesToRemoveIfDead) {
90   // Notify the rewriter subclass that we're about to replace this root.
91   notifyRootReplaced(op);
92 
93   assert(op->getNumResults() == newValues.size() &&
94          "incorrect # of replacement values");
95   op->replaceAllUsesWith(newValues);
96 
97   notifyOperationRemoved(op);
98   op->erase();
99 
100   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
101   // the notifyOperationRemoved hook in the process.
102 }
103 
104 /// This method erases an operation that is known to have no uses. The uses of
105 /// the given operation *must* be known to be dead.
106 void PatternRewriter::eraseOp(Operation *op) {
107   assert(op->use_empty() && "expected 'op' to have no uses");
108   notifyOperationRemoved(op);
109   op->erase();
110 }
111 
112 /// op and newOp are known to have the same number of results, replace the
113 /// uses of op with uses of newOp
114 void PatternRewriter::replaceOpWithResultsOfAnotherOp(
115     Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
116   assert(op->getNumResults() == newOp->getNumResults() &&
117          "replacement op doesn't match results of original op");
118   if (op->getNumResults() == 1)
119     return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
120 
121   SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
122                                      newOp->getResults().end());
123   return replaceOp(op, newResults, valuesToRemoveIfDead);
124 }
125 
126 /// Move the blocks that belong to "region" before the given position in
127 /// another region.  The two regions must be different.  The caller is in
128 /// charge to update create the operation transferring the control flow to the
129 /// region and pass it the correct block arguments.
130 void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
131                                          Region::iterator before) {
132   parent.getBlocks().splice(before, region.getBlocks());
133 }
134 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
135   inlineRegionBefore(region, *before->getParent(), before->getIterator());
136 }
137 
138 /// Clone the blocks that belong to "region" before the given position in
139 /// another region "parent". The two regions must be different. The caller is
140 /// responsible for creating or updating the operation transferring flow of
141 /// control to the region and passing it the correct block arguments.
142 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
143                                         Region::iterator before,
144                                         BlockAndValueMapping &mapping) {
145   region.cloneInto(&parent, before, mapping);
146 }
147 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
148                                         Region::iterator before) {
149   BlockAndValueMapping mapping;
150   cloneRegionBefore(region, parent, before, mapping);
151 }
152 void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
153   cloneRegionBefore(region, *before->getParent(), before->getIterator());
154 }
155 
156 /// This method is used as the final notification hook for patterns that end
157 /// up modifying the pattern root in place, by changing its operands.  This is
158 /// a minor efficiency win (it avoids creating a new operation and removing
159 /// the old one) but also often allows simpler code in the client.
160 ///
161 /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
162 /// should remove if they are dead at this point.
163 ///
164 void PatternRewriter::updatedRootInPlace(
165     Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
166   // Notify the rewriter subclass that we're about to replace this root.
167   notifyRootUpdated(op);
168 
169   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
170   // the notifyOperationRemoved hook in the process.
171 }
172 
173 //===----------------------------------------------------------------------===//
// RewritePatternMatcher implementation
175 //===----------------------------------------------------------------------===//
176 
177 RewritePatternMatcher::RewritePatternMatcher(
178     const OwningRewritePatternList &patterns) {
179   for (auto &pattern : patterns)
180     this->patterns.push_back(pattern.get());
181 
182   // Sort the patterns by benefit to simplify the matching logic.
183   std::stable_sort(this->patterns.begin(), this->patterns.end(),
184                    [](RewritePattern *l, RewritePattern *r) {
185                      return r->getBenefit() < l->getBenefit();
186                    });
187 }
188 
189 /// Try to match the given operation to a pattern and rewrite it.
190 bool RewritePatternMatcher::matchAndRewrite(Operation *op,
191                                             PatternRewriter &rewriter) {
192   for (auto *pattern : patterns) {
193     // Ignore patterns that are for the wrong root or are impossible to match.
194     if (pattern->getRootKind() != op->getName() ||
195         pattern->getBenefit().isImpossibleToMatch())
196       continue;
197 
198     // Try to match and rewrite this pattern. The patterns are sorted by
199     // benefit, so if we match we can immediately rewrite and return.
200     if (pattern->matchAndRewrite(op, rewriter))
201       return true;
202   }
203   return false;
204 }
205