1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11 
12 using namespace mlir;
13 
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17 
18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19   assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20          "This pattern match benefit is too large to represent");
21 }
22 
23 unsigned short PatternBenefit::getBenefit() const {
24   assert(!isImpossibleToMatch() && "Pattern doesn't match");
25   return representation;
26 }
27 
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31 
32 //===----------------------------------------------------------------------===//
33 // OperationName Root Constructors
34 
35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
36                  MLIRContext *context, ArrayRef<StringRef> generatedNames)
37     : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
38               RootKind::OperationName, generatedNames, benefit, context) {}
39 
40 //===----------------------------------------------------------------------===//
41 // MatchAnyOpTypeTag Root Constructors
42 
43 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
44                  MLIRContext *context, ArrayRef<StringRef> generatedNames)
45     : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
46 
47 //===----------------------------------------------------------------------===//
48 // MatchInterfaceOpTypeTag Root Constructors
49 
50 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
51                  PatternBenefit benefit, MLIRContext *context,
52                  ArrayRef<StringRef> generatedNames)
53     : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
54               generatedNames, benefit, context) {}
55 
56 //===----------------------------------------------------------------------===//
57 // MatchTraitOpTypeTag Root Constructors
58 
59 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
60                  PatternBenefit benefit, MLIRContext *context,
61                  ArrayRef<StringRef> generatedNames)
62     : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
63               benefit, context) {}
64 
65 //===----------------------------------------------------------------------===//
66 // General Constructors
67 
68 Pattern::Pattern(const void *rootValue, RootKind rootKind,
69                  ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
70                  MLIRContext *context)
71     : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
72       contextAndHasBoundedRecursion(context, false) {
73   if (generatedNames.empty())
74     return;
75   generatedOps.reserve(generatedNames.size());
76   std::transform(generatedNames.begin(), generatedNames.end(),
77                  std::back_inserter(generatedOps), [context](StringRef name) {
78                    return OperationName(name, context);
79                  });
80 }
81 
82 //===----------------------------------------------------------------------===//
83 // RewritePattern
84 //===----------------------------------------------------------------------===//
85 
86 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
87   llvm_unreachable("need to implement either matchAndRewrite or one of the "
88                    "rewrite functions!");
89 }
90 
91 LogicalResult RewritePattern::match(Operation *op) const {
92   llvm_unreachable("need to implement either match or matchAndRewrite!");
93 }
94 
95 /// Out-of-line vtable anchor.
96 void RewritePattern::anchor() {}
97 
98 //===----------------------------------------------------------------------===//
99 // PDLValue
100 //===----------------------------------------------------------------------===//
101 
102 void PDLValue::print(raw_ostream &os) const {
103   if (!value) {
104     os << "<NULL-PDLValue>";
105     return;
106   }
107   switch (kind) {
108   case Kind::Attribute:
109     os << cast<Attribute>();
110     break;
111   case Kind::Operation:
112     os << *cast<Operation *>();
113     break;
114   case Kind::Type:
115     os << cast<Type>();
116     break;
117   case Kind::TypeRange:
118     llvm::interleaveComma(cast<TypeRange>(), os);
119     break;
120   case Kind::Value:
121     os << cast<Value>();
122     break;
123   case Kind::ValueRange:
124     llvm::interleaveComma(cast<ValueRange>(), os);
125     break;
126   }
127 }
128 
129 void PDLValue::print(raw_ostream &os, Kind kind) {
130   switch (kind) {
131   case Kind::Attribute:
132     os << "Attribute";
133     break;
134   case Kind::Operation:
135     os << "Operation";
136     break;
137   case Kind::Type:
138     os << "Type";
139     break;
140   case Kind::TypeRange:
141     os << "TypeRange";
142     break;
143   case Kind::Value:
144     os << "Value";
145     break;
146   case Kind::ValueRange:
147     os << "ValueRange";
148     break;
149   }
150 }
151 
152 //===----------------------------------------------------------------------===//
153 // PDLPatternModule
154 //===----------------------------------------------------------------------===//
155 
156 void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
157   // Ignore the other module if it has no patterns.
158   if (!other.pdlModule)
159     return;
160 
161   // Steal the functions of the other module.
162   for (auto &it : other.constraintFunctions)
163     registerConstraintFunction(it.first(), std::move(it.second));
164   for (auto &it : other.rewriteFunctions)
165     registerRewriteFunction(it.first(), std::move(it.second));
166 
167   // Steal the other state if we have no patterns.
168   if (!pdlModule) {
169     pdlModule = std::move(other.pdlModule);
170     return;
171   }
172 
173   // Merge the pattern operations from the other module into this one.
174   Block *block = pdlModule->getBody();
175   block->getOperations().splice(block->end(),
176                                 other.pdlModule->getBody()->getOperations());
177 }
178 
179 //===----------------------------------------------------------------------===//
180 // Function Registry
181 
182 void PDLPatternModule::registerConstraintFunction(
183     StringRef name, PDLConstraintFunction constraintFn) {
184   // TODO: Is it possible to diagnose when `name` is already registered to
185   // a function that is not equivalent to `constraintFn`?
186   // Allow existing mappings in the case multiple patterns depend on the same
187   // constraint.
188   constraintFunctions.try_emplace(name, std::move(constraintFn));
189 }
190 
191 void PDLPatternModule::registerRewriteFunction(StringRef name,
192                                                PDLRewriteFunction rewriteFn) {
193   // TODO: Is it possible to diagnose when `name` is already registered to
194   // a function that is not equivalent to `rewriteFn`?
195   // Allow existing mappings in the case multiple patterns depend on the same
196   // rewrite.
197   rewriteFunctions.try_emplace(name, std::move(rewriteFn));
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // RewriterBase
202 //===----------------------------------------------------------------------===//
203 
204 RewriterBase::~RewriterBase() {
205   // Out of line to provide a vtable anchor for the class.
206 }
207 
208 /// This method replaces the uses of the results of `op` with the values in
209 /// `newValues` when the provided `functor` returns true for a specific use.
210 /// The number of values in `newValues` is required to match the number of
211 /// results of `op`.
212 void RewriterBase::replaceOpWithIf(
213     Operation *op, ValueRange newValues, bool *allUsesReplaced,
214     llvm::unique_function<bool(OpOperand &) const> functor) {
215   assert(op->getNumResults() == newValues.size() &&
216          "incorrect number of values to replace operation");
217 
218   // Notify the rewriter subclass that we're about to replace this root.
219   notifyRootReplaced(op);
220 
221   // Replace each use of the results when the functor is true.
222   bool replacedAllUses = true;
223   for (auto it : llvm::zip(op->getResults(), newValues)) {
224     std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
225     replacedAllUses &= std::get<0>(it).use_empty();
226   }
227   if (allUsesReplaced)
228     *allUsesReplaced = replacedAllUses;
229 }
230 
231 /// This method replaces the uses of the results of `op` with the values in
232 /// `newValues` when a use is nested within the given `block`. The number of
233 /// values in `newValues` is required to match the number of results of `op`.
234 /// If all uses of this operation are replaced, the operation is erased.
235 void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
236                                         Block *block, bool *allUsesReplaced) {
237   replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
238     return block->getParentOp()->isProperAncestor(use.getOwner());
239   });
240 }
241 
242 /// This method replaces the results of the operation with the specified list of
243 /// values. The number of provided values must match the number of results of
244 /// the operation.
245 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
246   // Notify the rewriter subclass that we're about to replace this root.
247   notifyRootReplaced(op);
248 
249   assert(op->getNumResults() == newValues.size() &&
250          "incorrect # of replacement values");
251   op->replaceAllUsesWith(newValues);
252 
253   notifyOperationRemoved(op);
254   op->erase();
255 }
256 
257 /// This method erases an operation that is known to have no uses. The uses of
258 /// the given operation *must* be known to be dead.
259 void RewriterBase::eraseOp(Operation *op) {
260   assert(op->use_empty() && "expected 'op' to have no uses");
261   notifyOperationRemoved(op);
262   op->erase();
263 }
264 
265 void RewriterBase::eraseBlock(Block *block) {
266   for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
267     assert(op.use_empty() && "expected 'op' to have no uses");
268     eraseOp(&op);
269   }
270   block->erase();
271 }
272 
273 /// Merge the operations of block 'source' into the end of block 'dest'.
274 /// 'source's predecessors must be empty or only contain 'dest`.
275 /// 'argValues' is used to replace the block arguments of 'source' after
276 /// merging.
277 void RewriterBase::mergeBlocks(Block *source, Block *dest,
278                                ValueRange argValues) {
279   assert(llvm::all_of(source->getPredecessors(),
280                       [dest](Block *succ) { return succ == dest; }) &&
281          "expected 'source' to have no predecessors or only 'dest'");
282   assert(argValues.size() == source->getNumArguments() &&
283          "incorrect # of argument replacement values");
284 
285   // Replace all of the successor arguments with the provided values.
286   for (auto it : llvm::zip(source->getArguments(), argValues))
287     std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
288 
289   // Splice the operations of the 'source' block into the 'dest' block and erase
290   // it.
291   dest->getOperations().splice(dest->end(), source->getOperations());
292   source->dropAllUses();
293   source->erase();
294 }
295 
296 // Merge the operations of block 'source' before the operation 'op'. Source
297 // block should not have existing predecessors or successors.
298 void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
299                                     ValueRange argValues) {
300   assert(source->hasNoPredecessors() &&
301          "expected 'source' to have no predecessors");
302   assert(source->hasNoSuccessors() &&
303          "expected 'source' to have no successors");
304 
305   // Split the block containing 'op' into two, one containing all operations
306   // before 'op' (prologue) and another (epilogue) containing 'op' and all
307   // operations after it.
308   Block *prologue = op->getBlock();
309   Block *epilogue = splitBlock(prologue, op->getIterator());
310 
311   // Merge the source block at the end of the prologue.
312   mergeBlocks(source, prologue, argValues);
313 
314   // Merge the epilogue at the end the prologue.
315   mergeBlocks(epilogue, prologue);
316 }
317 
318 /// Split the operations starting at "before" (inclusive) out of the given
319 /// block into a new block, and return it.
320 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
321   return block->splitBlock(before);
322 }
323 
324 /// 'op' and 'newOp' are known to have the same number of results, replace the
325 /// uses of op with uses of newOp
326 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
327                                                    Operation *newOp) {
328   assert(op->getNumResults() == newOp->getNumResults() &&
329          "replacement op doesn't match results of original op");
330   if (op->getNumResults() == 1)
331     return replaceOp(op, newOp->getResult(0));
332   return replaceOp(op, newOp->getResults());
333 }
334 
335 /// Move the blocks that belong to "region" before the given position in
336 /// another region.  The two regions must be different.  The caller is in
337 /// charge to update create the operation transferring the control flow to the
338 /// region and pass it the correct block arguments.
339 void RewriterBase::inlineRegionBefore(Region &region, Region &parent,
340                                       Region::iterator before) {
341   parent.getBlocks().splice(before, region.getBlocks());
342 }
343 void RewriterBase::inlineRegionBefore(Region &region, Block *before) {
344   inlineRegionBefore(region, *before->getParent(), before->getIterator());
345 }
346 
347 /// Clone the blocks that belong to "region" before the given position in
348 /// another region "parent". The two regions must be different. The caller is
349 /// responsible for creating or updating the operation transferring flow of
350 /// control to the region and passing it the correct block arguments.
351 void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
352                                      Region::iterator before,
353                                      BlockAndValueMapping &mapping) {
354   region.cloneInto(&parent, before, mapping);
355 }
356 void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
357                                      Region::iterator before) {
358   BlockAndValueMapping mapping;
359   cloneRegionBefore(region, parent, before, mapping);
360 }
361 void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
362   cloneRegionBefore(region, *before->getParent(), before->getIterator());
363 }
364