1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/IR/Operation.h"
20 #include "mlir/IR/Value.h"
21 using namespace mlir;
22 
23 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
24   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
25          "This pattern match benefit is too large to represent");
26 }
27 
28 unsigned short PatternBenefit::getBenefit() const {
29   assert(representation != ImpossibleToMatchSentinel &&
30          "Pattern doesn't match");
31   return representation;
32 }
33 
34 //===----------------------------------------------------------------------===//
35 // Pattern implementation
36 //===----------------------------------------------------------------------===//
37 
38 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
39                  MLIRContext *context)
40     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
41 
42 // Out-of-line vtable anchor.
43 void Pattern::anchor() {}
44 
45 //===----------------------------------------------------------------------===//
46 // RewritePattern and PatternRewriter implementation
47 //===----------------------------------------------------------------------===//
48 
49 void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state,
50                              PatternRewriter &rewriter) const {
51   rewrite(op, rewriter);
52 }
53 
54 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
55   llvm_unreachable("need to implement either matchAndRewrite or one of the "
56                    "rewrite functions!");
57 }
58 
59 PatternMatchResult RewritePattern::match(Operation *op) const {
60   llvm_unreachable("need to implement either match or matchAndRewrite!");
61 }
62 
63 /// Patterns must specify the root operation name they match against, and can
64 /// also specify the benefit of the pattern matching. They can also specify the
65 /// names of operations that may be generated during a successful rewrite.
66 RewritePattern::RewritePattern(StringRef rootName,
67                                ArrayRef<StringRef> generatedNames,
68                                PatternBenefit benefit, MLIRContext *context)
69     : Pattern(rootName, benefit, context) {
70   generatedOps.reserve(generatedNames.size());
71   std::transform(generatedNames.begin(), generatedNames.end(),
72                  std::back_inserter(generatedOps), [context](StringRef name) {
73                    return OperationName(name, context);
74                  });
75 }
76 
77 PatternRewriter::~PatternRewriter() {
78   // Out of line to provide a vtable anchor for the class.
79 }
80 
81 /// This method performs the final replacement for a pattern, where the
82 /// results of the operation are updated to use the specified list of SSA
83 /// values.  In addition to replacing and removing the specified operation,
84 /// clients can specify a list of other nodes that this replacement may make
85 /// (perhaps transitively) dead.  If any of those ops are dead, this will
86 /// remove them as well.
87 void PatternRewriter::replaceOp(Operation *op, ArrayRef<Value *> newValues,
88                                 ArrayRef<Value *> valuesToRemoveIfDead) {
89   // Notify the rewriter subclass that we're about to replace this root.
90   notifyRootReplaced(op);
91 
92   assert(op->getNumResults() == newValues.size() &&
93          "incorrect # of replacement values");
94   op->replaceAllUsesWith(newValues);
95 
96   notifyOperationRemoved(op);
97   op->erase();
98 
99   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
100   // the notifyOperationRemoved hook in the process.
101 }
102 
103 /// op and newOp are known to have the same number of results, replace the
104 /// uses of op with uses of newOp
105 void PatternRewriter::replaceOpWithResultsOfAnotherOp(
106     Operation *op, Operation *newOp, ArrayRef<Value *> valuesToRemoveIfDead) {
107   assert(op->getNumResults() == newOp->getNumResults() &&
108          "replacement op doesn't match results of original op");
109   if (op->getNumResults() == 1)
110     return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead);
111 
112   SmallVector<Value *, 8> newResults(newOp->getResults().begin(),
113                                      newOp->getResults().end());
114   return replaceOp(op, newResults, valuesToRemoveIfDead);
115 }
116 
117 /// Move the blocks that belong to "region" before the given position in
118 /// another region.  The two regions must be different.  The caller is in
119 /// charge to update create the operation transferring the control flow to the
120 /// region and pass it the correct block arguments.
121 void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
122                                          Region::iterator before) {
123   parent.getBlocks().splice(before, region.getBlocks());
124 }
125 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
126   inlineRegionBefore(region, *before->getParent(), before->getIterator());
127 }
128 
129 /// This method is used as the final notification hook for patterns that end
130 /// up modifying the pattern root in place, by changing its operands.  This is
131 /// a minor efficiency win (it avoids creating a new operation and removing
132 /// the old one) but also often allows simpler code in the client.
133 ///
134 /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter
135 /// should remove if they are dead at this point.
136 ///
137 void PatternRewriter::updatedRootInPlace(
138     Operation *op, ArrayRef<Value *> valuesToRemoveIfDead) {
139   // Notify the rewriter subclass that we're about to replace this root.
140   notifyRootUpdated(op);
141 
142   // TODO: Process the valuesToRemoveIfDead list, removing things and calling
143   // the notifyOperationRemoved hook in the process.
144 }
145 
146 //===----------------------------------------------------------------------===//
// RewritePatternMatcher implementation
148 //===----------------------------------------------------------------------===//
149 
150 RewritePatternMatcher::RewritePatternMatcher(
151     const OwningRewritePatternList &patterns) {
152   for (auto &pattern : patterns)
153     this->patterns.push_back(pattern.get());
154 
155   // Sort the patterns by benefit to simplify the matching logic.
156   std::stable_sort(this->patterns.begin(), this->patterns.end(),
157                    [](RewritePattern *l, RewritePattern *r) {
158                      return r->getBenefit() < l->getBenefit();
159                    });
160 }
161 
162 /// Try to match the given operation to a pattern and rewrite it.
163 bool RewritePatternMatcher::matchAndRewrite(Operation *op,
164                                             PatternRewriter &rewriter) {
165   for (auto *pattern : patterns) {
166     // Ignore patterns that are for the wrong root or are impossible to match.
167     if (pattern->getRootKind() != op->getName() ||
168         pattern->getBenefit().isImpossibleToMatch())
169       continue;
170 
171     // Try to match and rewrite this pattern. The patterns are sorted by
172     // benefit, so if we match we can immediately rewrite and return.
173     if (pattern->matchAndRewrite(op, rewriter))
174       return true;
175   }
176   return false;
177 }
178