1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17 
18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20          "This pattern match benefit is too large to represent");
21 }
22 
23 unsigned short PatternBenefit::getBenefit() const {
24   assert(!isImpossibleToMatch() && "Pattern doesn't match");
25   return representation;
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31 
32 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
33                  MLIRContext *context)
34     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
35 Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
36     : benefit(benefit) {}
37 Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
38                  PatternBenefit benefit, MLIRContext *context)
39     : Pattern(rootName, benefit, context) {
40   generatedOps.reserve(generatedNames.size());
41   std::transform(generatedNames.begin(), generatedNames.end(),
42                  std::back_inserter(generatedOps), [context](StringRef name) {
43                    return OperationName(name, context);
44                  });
45 }
46 Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
47                  MLIRContext *context, MatchAnyOpTypeTag tag)
48     : Pattern(benefit, tag) {
49   generatedOps.reserve(generatedNames.size());
50   std::transform(generatedNames.begin(), generatedNames.end(),
51                  std::back_inserter(generatedOps), [context](StringRef name) {
52                    return OperationName(name, context);
53                  });
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // RewritePattern
58 //===----------------------------------------------------------------------===//
59 
60 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
61   llvm_unreachable("need to implement either matchAndRewrite or one of the "
62                    "rewrite functions!");
63 }
64 
65 LogicalResult RewritePattern::match(Operation *op) const {
66   llvm_unreachable("need to implement either match or matchAndRewrite!");
67 }
68 
69 /// Out-of-line vtable anchor.
70 void RewritePattern::anchor() {}
71 
72 //===----------------------------------------------------------------------===//
73 // PatternRewriter
74 //===----------------------------------------------------------------------===//
75 
76 PatternRewriter::~PatternRewriter() {
77   // Out of line to provide a vtable anchor for the class.
78 }
79 
80 /// This method performs the final replacement for a pattern, where the
81 /// results of the operation are updated to use the specified list of SSA
82 /// values.
83 void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) {
84   // Notify the rewriter subclass that we're about to replace this root.
85   notifyRootReplaced(op);
86 
87   assert(op->getNumResults() == newValues.size() &&
88          "incorrect # of replacement values");
89   op->replaceAllUsesWith(newValues);
90 
91   notifyOperationRemoved(op);
92   op->erase();
93 }
94 
95 /// This method erases an operation that is known to have no uses. The uses of
96 /// the given operation *must* be known to be dead.
97 void PatternRewriter::eraseOp(Operation *op) {
98   assert(op->use_empty() && "expected 'op' to have no uses");
99   notifyOperationRemoved(op);
100   op->erase();
101 }
102 
103 void PatternRewriter::eraseBlock(Block *block) {
104   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
105     assert(op.use_empty() && "expected 'op' to have no uses");
106     eraseOp(&op);
107   }
108   block->erase();
109 }
110 
111 /// Merge the operations of block 'source' into the end of block 'dest'.
112 /// 'source's predecessors must be empty or only contain 'dest`.
113 /// 'argValues' is used to replace the block arguments of 'source' after
114 /// merging.
115 void PatternRewriter::mergeBlocks(Block *source, Block *dest,
116                                   ValueRange argValues) {
117   assert(llvm::all_of(source->getPredecessors(),
118                       [dest](Block *succ) { return succ == dest; }) &&
119          "expected 'source' to have no predecessors or only 'dest'");
120   assert(argValues.size() == source->getNumArguments() &&
121          "incorrect # of argument replacement values");
122 
123   // Replace all of the successor arguments with the provided values.
124   for (auto it : llvm::zip(source->getArguments(), argValues))
125     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
126 
127   // Splice the operations of the 'source' block into the 'dest' block and erase
128   // it.
129   dest->getOperations().splice(dest->end(), source->getOperations());
130   source->dropAllUses();
131   source->erase();
132 }
133 
134 // Merge the operations of block 'source' before the operation 'op'. Source
135 // block should not have existing predecessors or successors.
136 void PatternRewriter::mergeBlockBefore(Block *source, Operation *op,
137                                        ValueRange argValues) {
138   assert(source->hasNoPredecessors() &&
139          "expected 'source' to have no predecessors");
140   assert(source->hasNoSuccessors() &&
141          "expected 'source' to have no successors");
142 
143   // Split the block containing 'op' into two, one containing all operations
144   // before 'op' (prologue) and another (epilogue) containing 'op' and all
145   // operations after it.
146   Block *prologue = op->getBlock();
147   Block *epilogue = splitBlock(prologue, op->getIterator());
148 
149   // Merge the source block at the end of the prologue.
150   mergeBlocks(source, prologue, argValues);
151 
152   // Merge the epilogue at the end the prologue.
153   mergeBlocks(epilogue, prologue);
154 }
155 
156 /// Split the operations starting at "before" (inclusive) out of the given
157 /// block into a new block, and return it.
158 Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) {
159   return block->splitBlock(before);
160 }
161 
162 /// 'op' and 'newOp' are known to have the same number of results, replace the
163 /// uses of op with uses of newOp
164 void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op,
165                                                       Operation *newOp) {
166   assert(op->getNumResults() == newOp->getNumResults() &&
167          "replacement op doesn't match results of original op");
168   if (op->getNumResults() == 1)
169     return replaceOp(op, newOp->getResult(0));
170   return replaceOp(op, newOp->getResults());
171 }
172 
173 /// Move the blocks that belong to "region" before the given position in
174 /// another region.  The two regions must be different.  The caller is in
175 /// charge to update create the operation transferring the control flow to the
176 /// region and pass it the correct block arguments.
177 void PatternRewriter::inlineRegionBefore(Region &region, Region &parent,
178                                          Region::iterator before) {
179   parent.getBlocks().splice(before, region.getBlocks());
180 }
181 void PatternRewriter::inlineRegionBefore(Region &region, Block *before) {
182   inlineRegionBefore(region, *before->getParent(), before->getIterator());
183 }
184 
185 /// Clone the blocks that belong to "region" before the given position in
186 /// another region "parent". The two regions must be different. The caller is
187 /// responsible for creating or updating the operation transferring flow of
188 /// control to the region and passing it the correct block arguments.
189 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
190                                         Region::iterator before,
191                                         BlockAndValueMapping &mapping) {
192   region.cloneInto(&parent, before, mapping);
193 }
194 void PatternRewriter::cloneRegionBefore(Region &region, Region &parent,
195                                         Region::iterator before) {
196   BlockAndValueMapping mapping;
197   cloneRegionBefore(region, parent, before, mapping);
198 }
199 void PatternRewriter::cloneRegionBefore(Region &region, Block *before) {
200   cloneRegionBefore(region, *before->getParent(), before->getIterator());
201 }
202 
203