1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/BlockAndValueMapping.h"
20 #include "mlir/IR/Operation.h"
21 #include "mlir/IR/Value.h"
22 using namespace mlir;
23 
/// Construct a benefit from an unsigned value. The value is stored in
/// `representation` (presumably narrower than `unsigned`, since getBenefit()
/// returns `unsigned short`). The assert rejects values that do not survive
/// that conversion, as well as the reserved ImpossibleToMatchSentinel value.
PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
  // Only checked in assert-enabled builds; release builds truncate silently.
  assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
         "This pattern match benefit is too large to represent");
}
28 
/// Return the numeric benefit value. Calling this on a benefit that holds
/// ImpossibleToMatchSentinel is a programming error (asserted).
unsigned short PatternBenefit::getBenefit() const {
  assert(representation != ImpossibleToMatchSentinel &&
         "Pattern doesn't match");
  return representation;
}
34 
35 //===----------------------------------------------------------------------===//
36 // Pattern implementation
37 //===----------------------------------------------------------------------===//
38 
39 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
40                  MLIRContext *context)
41     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
42 
// Out-of-line vtable anchor: defining this (empty) virtual method in this
// file gives the class a home translation unit for its vtable.
void Pattern::anchor() {}
45 
46 //===----------------------------------------------------------------------===//
47 // RewritePattern and PatternRewriter implementation
48 //===----------------------------------------------------------------------===//
49 
50 void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
51                              PatternRewriter &rewriter) const {
52   rewrite(op, rewriter);
53 }
54 
55 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
56   llvm_unreachable("need to implement either matchAndRewrite or one of the "
57                    "rewrite functions!");
58 }
59 
60 PatternMatchResult RewritePattern::match(Operation *op) const {
61   llvm_unreachable("need to implement either match or matchAndRewrite!");
62 }
63 
64 /// Patterns must specify the root operation name they match against, and can
65 /// also specify the benefit of the pattern matching. They can also specify the
66 /// names of operations that may be generated during a successful rewrite.
67 RewritePattern::RewritePattern(StringRef rootName,
68                                ArrayRef<StringRef> generatedNames,
69                                PatternBenefit benefit, MLIRContext *context)
70     : Pattern(rootName, benefit, context) {
71   generatedOps.reserve(generatedNames.size());
72   std::transform(generatedNames.begin(), generatedNames.end(),
73                  std::back_inserter(generatedOps), [context](StringRef name) {
74                    return OperationName(name, context);
75                  });
76 }
77 
78 PatternRewriter::~PatternRewriter() {
79   // Out of line to provide a vtable anchor for the class.
80 }
81 
/// This method performs the final replacement for a pattern: every use of
/// the results of `op` is redirected to the specified list of SSA values,
/// and `op` is then erased.
///
/// Clients may also pass a list of other values that this replacement may
/// make (perhaps transitively) dead. NOTE: that list is not processed yet
/// (see the TODO at the end of the body), so no additional operations are
/// removed by this method today.
///
/// \param op the root operation to replace; it is erased before returning.
/// \param newValues one replacement value per result of `op` (asserted).
/// \param valuesToRemoveIfDead candidate values for dead-code removal;
///        currently unused.
void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
                                ArrayRef<Value *> valuesToRemoveIfDead) {
  // Notify the rewriter subclass that we're about to replace this root.
  notifyRootReplaced(op);

  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");
  // Redirect each use of op's results to the corresponding new value.
  op->replaceAllUsesWith(newValues);

  // Notify the subclass while `op` still exists, then erase it.
  notifyOperationRemoved(op);
  op->erase();

  // TODO: Process the valuesToRemoveIfDead list, removing things and calling
  // the notifyOperationRemoved hook in the process.
}
103 
104 /// op and newOp are known to have the same number of results, replace the
105 /// uses of op with uses of newOp
106 void PatternRewriter::replaceOpWithResultsOfAnotherOp(
107     Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
108   assert(op->getNumResults() == newOp->getNumResults() &&
109          "replacement op doesn't match results of original op");
110   if (op->getNumResults() == 1)
111     return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
112 
113   SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
114                                      newOp->getResults().end());
115   return replaceOp(op, newResults, valuesToRemoveIfDead);
116 }
117 
118 /// Move the blocks that belong to "region" before the given position in
119 /// another region.  The two regions must be different.  The caller is in
120 /// charge to update create the operation transferring the control flow to the
121 /// region and pass it the correct block arguments.
122 void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
123                                          Region::iterator before) {
124   parent.getBlocks().splice(before, region.getBlocks());
125 }
126 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
127   inlineRegionBefore(region, *before->getParent(), before->getIterator());
128 }
129 
130 /// Clone the blocks that belong to "region" before the given position in
131 /// another region "parent". The two regions must be different. The caller is
132 /// responsible for creating or updating the operation transferring flow of
133 /// control to the region and passing it the correct block arguments.
134 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
135                                         Region::iterator before,
136                                         BlockAndValueMapping &mapping) {
137   region.cloneInto(&parent, before, mapping);
138 }
139 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
140                                         Region::iterator before) {
141   BlockAndValueMapping mapping;
142   cloneRegionBefore(region, parent, before, mapping);
143 }
144 void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
145   cloneRegionBefore(region, *before->getParent(), before->getIterator());
146 }
147 
148 /// This method is used as the final notification hook for patterns that end
149 /// up modifying the pattern root in place, by changing its operands.  This is
150 /// a minor efficiency win (it avoids creating a new operation and removing
151 /// the old one) but also often allows simpler code in the client.
152 ///
153 /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
154 /// should remove if they are dead at this point.
155 ///
156 void PatternRewriter::updatedRootInPlace(
157     Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
158   // Notify the rewriter subclass that we're about to replace this root.
159   notifyRootUpdated(op);
160 
161   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
162   // the notifyOperationRemoved hook in the process.
163 }
164 
165 //===----------------------------------------------------------------------===//
// RewritePatternMatcher implementation
167 //===----------------------------------------------------------------------===//
168 
169 RewritePatternMatcher::RewritePatternMatcher(
170     const OwningRewritePatternList &patterns) {
171   for (auto &pattern : patterns)
172     this->patterns.push_back(pattern.get());
173 
174   // Sort the patterns by benefit to simplify the matching logic.
175   std::stable_sort(this->patterns.begin(), this->patterns.end(),
176                    [](RewritePattern *l, RewritePattern *r) {
177                      return r->getBenefit() < l->getBenefit();
178                    });
179 }
180 
181 /// Try to match the given operation to a pattern and rewrite it.
182 bool RewritePatternMatcher::matchAndRewrite(Operation *op,
183                                             PatternRewriter &rewriter) {
184   for (auto *pattern : patterns) {
185     // Ignore patterns that are for the wrong root or are impossible to match.
186     if (pattern->getRootKind() != op->getName() ||
187         pattern->getBenefit().isImpossibleToMatch())
188       continue;
189 
190     // Try to match and rewrite this pattern. The patterns are sorted by
191     // benefit, so if we match we can immediately rewrite and return.
192     if (pattern->matchAndRewrite(op, rewriter))
193       return true;
194   }
195   return false;
196 }
197