1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17 
18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20          "This pattern match benefit is too large to represent");
21 }
22 
23 unsigned short PatternBenefit::getBenefit() const {
24   assert(!isImpossibleToMatch() && "Pattern doesn't match");
25   return representation;
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31 
32 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
33                  MLIRContext *context)
34     : rootKind(OperationName(rootName, context)), benefit(benefit) {}
35 Pattern::Pattern(PatternBenefit benefit, MatchAnyOpTypeTag tag)
36     : benefit(benefit) {}
37 Pattern::Pattern(StringRef rootName, ArrayRef<StringRef> generatedNames,
38                  PatternBenefit benefit, MLIRContext *context)
39     : Pattern(rootName, benefit, context) {
40   generatedOps.reserve(generatedNames.size());
41   std::transform(generatedNames.begin(), generatedNames.end(),
42                  std::back_inserter(generatedOps), [context](StringRef name) {
43                    return OperationName(name, context);
44                  });
45 }
46 Pattern::Pattern(ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
47                  MLIRContext *context, MatchAnyOpTypeTag tag)
48     : Pattern(benefit, tag) {
49   generatedOps.reserve(generatedNames.size());
50   std::transform(generatedNames.begin(), generatedNames.end(),
51                  std::back_inserter(generatedOps), [context](StringRef name) {
52                    return OperationName(name, context);
53                  });
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // RewritePattern
58 //===----------------------------------------------------------------------===//
59 
60 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
61   llvm_unreachable("need to implement either matchAndRewrite or one of the "
62                    "rewrite functions!");
63 }
64 
65 LogicalResult RewritePattern::match(Operation *op) const {
66   llvm_unreachable("need to implement either match or matchAndRewrite!");
67 }
68 
69 /// Out-of-line vtable anchor.
70 void RewritePattern::anchor() {}
71 
72 //===----------------------------------------------------------------------===//
73 // PDLValue
74 //===----------------------------------------------------------------------===//
75 
76 void PDLValue::print(raw_ostream &os) const {
77   if (!value) {
78     os << "<NULL-PDLValue>";
79     return;
80   }
81   switch (kind) {
82   case Kind::Attribute:
83     os << cast<Attribute>();
84     break;
85   case Kind::Operation:
86     os << *cast<Operation *>();
87     break;
88   case Kind::Type:
89     os << cast<Type>();
90     break;
91   case Kind::TypeRange:
92     llvm::interleaveComma(cast<TypeRange>(), os);
93     break;
94   case Kind::Value:
95     os << cast<Value>();
96     break;
97   case Kind::ValueRange:
98     llvm::interleaveComma(cast<ValueRange>(), os);
99     break;
100   }
101 }
102 
103 //===----------------------------------------------------------------------===//
104 // PDLPatternModule
105 //===----------------------------------------------------------------------===//
106 
107 void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
108   // Ignore the other module if it has no patterns.
109   if (!other.pdlModule)
110     return;
111   // Steal the other state if we have no patterns.
112   if (!pdlModule) {
113     constraintFunctions = std::move(other.constraintFunctions);
114     rewriteFunctions = std::move(other.rewriteFunctions);
115     pdlModule = std::move(other.pdlModule);
116     return;
117   }
118   // Steal the functions of the other module.
119   for (auto &it : constraintFunctions)
120     registerConstraintFunction(it.first(), std::move(it.second));
121   for (auto &it : rewriteFunctions)
122     registerRewriteFunction(it.first(), std::move(it.second));
123 
124   // Merge the pattern operations from the other module into this one.
125   Block *block = pdlModule->getBody();
126   block->getTerminator()->erase();
127   block->getOperations().splice(block->end(),
128                                 other.pdlModule->getBody()->getOperations());
129 }
130 
131 //===----------------------------------------------------------------------===//
132 // Function Registry
133 
134 void PDLPatternModule::registerConstraintFunction(
135     StringRef name, PDLConstraintFunction constraintFn) {
136   auto it = constraintFunctions.try_emplace(name, std::move(constraintFn));
137   (void)it;
138   assert(it.second &&
139          "constraint with the given name has already been registered");
140 }
141 
142 void PDLPatternModule::registerRewriteFunction(StringRef name,
143                                                PDLRewriteFunction rewriteFn) {
144   auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn));
145   (void)it;
146   assert(it.second && "native rewrite function with the given name has "
147                       "already been registered");
148 }
149 
150 //===----------------------------------------------------------------------===//
151 // RewriterBase
152 //===----------------------------------------------------------------------===//
153 
154 RewriterBase::~RewriterBase() {
155   // Out of line to provide a vtable anchor for the class.
156 }
157 
158 /// This method replaces the uses of the results of `op` with the values in
159 /// `newValues` when the provided `functor` returns true for a specific use.
160 /// The number of values in `newValues` is required to match the number of
161 /// results of `op`.
162 void RewriterBase::replaceOpWithIf(
163     Operation *op, ValueRange newValues, bool *allUsesReplaced,
164     llvm::unique_function<bool(OpOperand &) const> functor) {
165   assert(op->getNumResults() == newValues.size() &&
166          "incorrect number of values to replace operation");
167 
168   // Notify the rewriter subclass that we're about to replace this root.
169   notifyRootReplaced(op);
170 
171   // Replace each use of the results when the functor is true.
172   bool replacedAllUses = true;
173   for (auto it : llvm::zip(op->getResults(), newValues)) {
174     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
175     replacedAllUses &= std::get<0>(it).use_empty();
176   }
177   if (allUsesReplaced)
178     *allUsesReplaced = replacedAllUses;
179 }
180 
181 /// This method replaces the uses of the results of `op` with the values in
182 /// `newValues` when a use is nested within the given `block`. The number of
183 /// values in `newValues` is required to match the number of results of `op`.
184 /// If all uses of this operation are replaced, the operation is erased.
185 void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
186                                         Block *block, bool *allUsesReplaced) {
187   replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
188     return block->getParentOp()->isProperAncestor(use.getOwner());
189   });
190 }
191 
192 /// This method replaces the results of the operation with the specified list of
193 /// values. The number of provided values must match the number of results of
194 /// the operation.
195 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
196   // Notify the rewriter subclass that we're about to replace this root.
197   notifyRootReplaced(op);
198 
199   assert(op->getNumResults() == newValues.size() &&
200          "incorrect # of replacement values");
201   op->replaceAllUsesWith(newValues);
202 
203   notifyOperationRemoved(op);
204   op->erase();
205 }
206 
207 /// This method erases an operation that is known to have no uses. The uses of
208 /// the given operation *must* be known to be dead.
209 void RewriterBase::eraseOp(Operation *op) {
210   assert(op->use_empty() && "expected 'op' to have no uses");
211   notifyOperationRemoved(op);
212   op->erase();
213 }
214 
215 void RewriterBase::eraseBlock(Block *block) {
216   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
217     assert(op.use_empty() && "expected 'op' to have no uses");
218     eraseOp(&op);
219   }
220   block->erase();
221 }
222 
223 /// Merge the operations of block 'source' into the end of block 'dest'.
224 /// 'source's predecessors must be empty or only contain 'dest`.
225 /// 'argValues' is used to replace the block arguments of 'source' after
226 /// merging.
227 void RewriterBase::mergeBlocks(Block *source, Block *dest,
228                                ValueRange argValues) {
229   assert(llvm::all_of(source->getPredecessors(),
230                       [dest](Block *succ) { return succ == dest; }) &&
231          "expected 'source' to have no predecessors or only 'dest'");
232   assert(argValues.size() == source->getNumArguments() &&
233          "incorrect # of argument replacement values");
234 
235   // Replace all of the successor arguments with the provided values.
236   for (auto it : llvm::zip(source->getArguments(), argValues))
237     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
238 
239   // Splice the operations of the 'source' block into the 'dest' block and erase
240   // it.
241   dest->getOperations().splice(dest->end(), source->getOperations());
242   source->dropAllUses();
243   source->erase();
244 }
245 
246 // Merge the operations of block 'source' before the operation 'op'. Source
247 // block should not have existing predecessors or successors.
248 void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
249                                     ValueRange argValues) {
250   assert(source->hasNoPredecessors() &&
251          "expected 'source' to have no predecessors");
252   assert(source->hasNoSuccessors() &&
253          "expected 'source' to have no successors");
254 
255   // Split the block containing 'op' into two, one containing all operations
256   // before 'op' (prologue) and another (epilogue) containing 'op' and all
257   // operations after it.
258   Block *prologue = op->getBlock();
259   Block *epilogue = splitBlock(prologue, op->getIterator());
260 
261   // Merge the source block at the end of the prologue.
262   mergeBlocks(source, prologue, argValues);
263 
264   // Merge the epilogue at the end the prologue.
265   mergeBlocks(epilogue, prologue);
266 }
267 
268 /// Split the operations starting at "before" (inclusive) out of the given
269 /// block into a new block, and return it.
270 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
271   return block->splitBlock(before);
272 }
273 
274 /// 'op' and 'newOp' are known to have the same number of results, replace the
275 /// uses of op with uses of newOp
276 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
277                                                    Operation *newOp) {
278   assert(op->getNumResults() == newOp->getNumResults() &&
279          "replacement op doesn't match results of original op");
280   if (op->getNumResults() == 1)
281     return replaceOp(op, newOp->getResult(0));
282   return replaceOp(op, newOp->getResults());
283 }
284 
285 /// Move the blocks that belong to "region" before the given position in
286 /// another region.  The two regions must be different.  The caller is in
287 /// charge to update create the operation transferring the control flow to the
288 /// region and pass it the correct block arguments.
289 void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
290                                       Region::iterator before) {
291   parent.getBlocks().splice(before, region.getBlocks());
292 }
293 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
294   inlineRegionBefore(region, *before->getParent(), before->getIterator());
295 }
296 
297 /// Clone the blocks that belong to "region" before the given position in
298 /// another region "parent". The two regions must be different. The caller is
299 /// responsible for creating or updating the operation transferring flow of
300 /// control to the region and passing it the correct block arguments.
301 void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
302                                      Region::iterator before,
303                                      BlockAndValueMapping &mapping) {
304   region.cloneInto(&parent, before, mapping);
305 }
306 void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
307                                      Region::iterator before) {
308   BlockAndValueMapping mapping;
309   cloneRegionBefore(region, parent, before, mapping);
310 }
311 void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
312   cloneRegionBefore(region, *before->getParent(), before->getIterator());
313 }
314