1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/PatternMatch.h"
10 #include "mlir/IR/BlockAndValueMapping.h"
11
12 using namespace mlir;
13
14 //===----------------------------------------------------------------------===//
15 // PatternBenefit
16 //===----------------------------------------------------------------------===//
17
PatternBenefit(unsigned benefit)18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) {
19 assert(representation == benefit && benefit != ImpossibleToMatchSentinel &&
20 "This pattern match benefit is too large to represent");
21 }
22
getBenefit() const23 unsigned short PatternBenefit::getBenefit() const {
24 assert(!isImpossibleToMatch() && "Pattern doesn't match");
25 return representation;
26 }
27
28 //===----------------------------------------------------------------------===//
29 // Pattern
30 //===----------------------------------------------------------------------===//
31
32 //===----------------------------------------------------------------------===//
33 // OperationName Root Constructors
34
Pattern(StringRef rootName,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit,
36 MLIRContext *context, ArrayRef<StringRef> generatedNames)
37 : Pattern(OperationName(rootName, context).getAsOpaquePointer(),
38 RootKind::OperationName, generatedNames, benefit, context) {}
39
40 //===----------------------------------------------------------------------===//
41 // MatchAnyOpTypeTag Root Constructors
42
Pattern(MatchAnyOpTypeTag tag,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)43 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit,
44 MLIRContext *context, ArrayRef<StringRef> generatedNames)
45 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {}
46
47 //===----------------------------------------------------------------------===//
48 // MatchInterfaceOpTypeTag Root Constructors
49
Pattern(MatchInterfaceOpTypeTag tag,TypeID interfaceID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)50 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
51 PatternBenefit benefit, MLIRContext *context,
52 ArrayRef<StringRef> generatedNames)
53 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID,
54 generatedNames, benefit, context) {}
55
56 //===----------------------------------------------------------------------===//
57 // MatchTraitOpTypeTag Root Constructors
58
Pattern(MatchTraitOpTypeTag tag,TypeID traitID,PatternBenefit benefit,MLIRContext * context,ArrayRef<StringRef> generatedNames)59 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID,
60 PatternBenefit benefit, MLIRContext *context,
61 ArrayRef<StringRef> generatedNames)
62 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames,
63 benefit, context) {}
64
65 //===----------------------------------------------------------------------===//
66 // General Constructors
67
Pattern(const void * rootValue,RootKind rootKind,ArrayRef<StringRef> generatedNames,PatternBenefit benefit,MLIRContext * context)68 Pattern::Pattern(const void *rootValue, RootKind rootKind,
69 ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
70 MLIRContext *context)
71 : rootValue(rootValue), rootKind(rootKind), benefit(benefit),
72 contextAndHasBoundedRecursion(context, false) {
73 if (generatedNames.empty())
74 return;
75 generatedOps.reserve(generatedNames.size());
76 std::transform(generatedNames.begin(), generatedNames.end(),
77 std::back_inserter(generatedOps), [context](StringRef name) {
78 return OperationName(name, context);
79 });
80 }
81
82 //===----------------------------------------------------------------------===//
83 // RewritePattern
84 //===----------------------------------------------------------------------===//
85
rewrite(Operation * op,PatternRewriter & rewriter) const86 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const {
87 llvm_unreachable("need to implement either matchAndRewrite or one of the "
88 "rewrite functions!");
89 }
90
match(Operation * op) const91 LogicalResult RewritePattern::match(Operation *op) const {
92 llvm_unreachable("need to implement either match or matchAndRewrite!");
93 }
94
95 /// Out-of-line vtable anchor.
anchor()96 void RewritePattern::anchor() {}
97
98 //===----------------------------------------------------------------------===//
99 // PDLValue
100 //===----------------------------------------------------------------------===//
101
print(raw_ostream & os) const102 void PDLValue::print(raw_ostream &os) const {
103 if (!value) {
104 os << "<NULL-PDLValue>";
105 return;
106 }
107 switch (kind) {
108 case Kind::Attribute:
109 os << cast<Attribute>();
110 break;
111 case Kind::Operation:
112 os << *cast<Operation *>();
113 break;
114 case Kind::Type:
115 os << cast<Type>();
116 break;
117 case Kind::TypeRange:
118 llvm::interleaveComma(cast<TypeRange>(), os);
119 break;
120 case Kind::Value:
121 os << cast<Value>();
122 break;
123 case Kind::ValueRange:
124 llvm::interleaveComma(cast<ValueRange>(), os);
125 break;
126 }
127 }
128
print(raw_ostream & os,Kind kind)129 void PDLValue::print(raw_ostream &os, Kind kind) {
130 switch (kind) {
131 case Kind::Attribute:
132 os << "Attribute";
133 break;
134 case Kind::Operation:
135 os << "Operation";
136 break;
137 case Kind::Type:
138 os << "Type";
139 break;
140 case Kind::TypeRange:
141 os << "TypeRange";
142 break;
143 case Kind::Value:
144 os << "Value";
145 break;
146 case Kind::ValueRange:
147 os << "ValueRange";
148 break;
149 }
150 }
151
152 //===----------------------------------------------------------------------===//
153 // PDLPatternModule
154 //===----------------------------------------------------------------------===//
155
mergeIn(PDLPatternModule && other)156 void PDLPatternModule::mergeIn(PDLPatternModule &&other) {
157 // Ignore the other module if it has no patterns.
158 if (!other.pdlModule)
159 return;
160
161 // Steal the functions of the other module.
162 for (auto &it : other.constraintFunctions)
163 registerConstraintFunction(it.first(), std::move(it.second));
164 for (auto &it : other.rewriteFunctions)
165 registerRewriteFunction(it.first(), std::move(it.second));
166
167 // Steal the other state if we have no patterns.
168 if (!pdlModule) {
169 pdlModule = std::move(other.pdlModule);
170 return;
171 }
172
173 // Merge the pattern operations from the other module into this one.
174 Block *block = pdlModule->getBody();
175 block->getOperations().splice(block->end(),
176 other.pdlModule->getBody()->getOperations());
177 }
178
179 //===----------------------------------------------------------------------===//
180 // Function Registry
181
registerConstraintFunction(StringRef name,PDLConstraintFunction constraintFn)182 void PDLPatternModule::registerConstraintFunction(
183 StringRef name, PDLConstraintFunction constraintFn) {
184 // TODO: Is it possible to diagnose when `name` is already registered to
185 // a function that is not equivalent to `constraintFn`?
186 // Allow existing mappings in the case multiple patterns depend on the same
187 // constraint.
188 constraintFunctions.try_emplace(name, std::move(constraintFn));
189 }
190
registerRewriteFunction(StringRef name,PDLRewriteFunction rewriteFn)191 void PDLPatternModule::registerRewriteFunction(StringRef name,
192 PDLRewriteFunction rewriteFn) {
193 // TODO: Is it possible to diagnose when `name` is already registered to
194 // a function that is not equivalent to `rewriteFn`?
195 // Allow existing mappings in the case multiple patterns depend on the same
196 // rewrite.
197 rewriteFunctions.try_emplace(name, std::move(rewriteFn));
198 }
199
200 //===----------------------------------------------------------------------===//
201 // RewriterBase
202 //===----------------------------------------------------------------------===//
203
~RewriterBase()204 RewriterBase::~RewriterBase() {
205 // Out of line to provide a vtable anchor for the class.
206 }
207
208 /// This method replaces the uses of the results of `op` with the values in
209 /// `newValues` when the provided `functor` returns true for a specific use.
210 /// The number of values in `newValues` is required to match the number of
211 /// results of `op`.
replaceOpWithIf(Operation * op,ValueRange newValues,bool * allUsesReplaced,llvm::unique_function<bool (OpOperand &)const> functor)212 void RewriterBase::replaceOpWithIf(
213 Operation *op, ValueRange newValues, bool *allUsesReplaced,
214 llvm::unique_function<bool(OpOperand &) const> functor) {
215 assert(op->getNumResults() == newValues.size() &&
216 "incorrect number of values to replace operation");
217
218 // Notify the rewriter subclass that we're about to replace this root.
219 notifyRootReplaced(op);
220
221 // Replace each use of the results when the functor is true.
222 bool replacedAllUses = true;
223 for (auto it : llvm::zip(op->getResults(), newValues)) {
224 std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor);
225 replacedAllUses &= std::get<0>(it).use_empty();
226 }
227 if (allUsesReplaced)
228 *allUsesReplaced = replacedAllUses;
229 }
230
231 /// This method replaces the uses of the results of `op` with the values in
232 /// `newValues` when a use is nested within the given `block`. The number of
233 /// values in `newValues` is required to match the number of results of `op`.
234 /// If all uses of this operation are replaced, the operation is erased.
replaceOpWithinBlock(Operation * op,ValueRange newValues,Block * block,bool * allUsesReplaced)235 void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues,
236 Block *block, bool *allUsesReplaced) {
237 replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) {
238 return block->getParentOp()->isProperAncestor(use.getOwner());
239 });
240 }
241
242 /// This method replaces the results of the operation with the specified list of
243 /// values. The number of provided values must match the number of results of
244 /// the operation.
replaceOp(Operation * op,ValueRange newValues)245 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
246 // Notify the rewriter subclass that we're about to replace this root.
247 notifyRootReplaced(op);
248
249 assert(op->getNumResults() == newValues.size() &&
250 "incorrect # of replacement values");
251 op->replaceAllUsesWith(newValues);
252
253 notifyOperationRemoved(op);
254 op->erase();
255 }
256
257 /// This method erases an operation that is known to have no uses. The uses of
258 /// the given operation *must* be known to be dead.
eraseOp(Operation * op)259 void RewriterBase::eraseOp(Operation *op) {
260 assert(op->use_empty() && "expected 'op' to have no uses");
261 notifyOperationRemoved(op);
262 op->erase();
263 }
264
eraseBlock(Block * block)265 void RewriterBase::eraseBlock(Block *block) {
266 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) {
267 assert(op.use_empty() && "expected 'op' to have no uses");
268 eraseOp(&op);
269 }
270 block->erase();
271 }
272
273 /// Merge the operations of block 'source' into the end of block 'dest'.
274 /// 'source's predecessors must be empty or only contain 'dest`.
275 /// 'argValues' is used to replace the block arguments of 'source' after
276 /// merging.
mergeBlocks(Block * source,Block * dest,ValueRange argValues)277 void RewriterBase::mergeBlocks(Block *source, Block *dest,
278 ValueRange argValues) {
279 assert(llvm::all_of(source->getPredecessors(),
280 [dest](Block *succ) { return succ == dest; }) &&
281 "expected 'source' to have no predecessors or only 'dest'");
282 assert(argValues.size() == source->getNumArguments() &&
283 "incorrect # of argument replacement values");
284
285 // Replace all of the successor arguments with the provided values.
286 for (auto it : llvm::zip(source->getArguments(), argValues))
287 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
288
289 // Splice the operations of the 'source' block into the 'dest' block and erase
290 // it.
291 dest->getOperations().splice(dest->end(), source->getOperations());
292 source->dropAllUses();
293 source->erase();
294 }
295
296 // Merge the operations of block 'source' before the operation 'op'. Source
297 // block should not have existing predecessors or successors.
mergeBlockBefore(Block * source,Operation * op,ValueRange argValues)298 void RewriterBase::mergeBlockBefore(Block *source, Operation *op,
299 ValueRange argValues) {
300 assert(source->hasNoPredecessors() &&
301 "expected 'source' to have no predecessors");
302 assert(source->hasNoSuccessors() &&
303 "expected 'source' to have no successors");
304
305 // Split the block containing 'op' into two, one containing all operations
306 // before 'op' (prologue) and another (epilogue) containing 'op' and all
307 // operations after it.
308 Block *prologue = op->getBlock();
309 Block *epilogue = splitBlock(prologue, op->getIterator());
310
311 // Merge the source block at the end of the prologue.
312 mergeBlocks(source, prologue, argValues);
313
314 // Merge the epilogue at the end the prologue.
315 mergeBlocks(epilogue, prologue);
316 }
317
318 /// Split the operations starting at "before" (inclusive) out of the given
319 /// block into a new block, and return it.
splitBlock(Block * block,Block::iterator before)320 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) {
321 return block->splitBlock(before);
322 }
323
324 /// 'op' and 'newOp' are known to have the same number of results, replace the
325 /// uses of op with uses of newOp
replaceOpWithResultsOfAnotherOp(Operation * op,Operation * newOp)326 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op,
327 Operation *newOp) {
328 assert(op->getNumResults() == newOp->getNumResults() &&
329 "replacement op doesn't match results of original op");
330 if (op->getNumResults() == 1)
331 return replaceOp(op, newOp->getResult(0));
332 return replaceOp(op, newOp->getResults());
333 }
334
335 /// Move the blocks that belong to "region" before the given position in
336 /// another region. The two regions must be different. The caller is in
337 /// charge to update create the operation transferring the control flow to the
338 /// region and pass it the correct block arguments.
inlineRegionBefore(Region & region,Region & parent,Region::iterator before)339 void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent,
340 Region::iterator before) {
341 parent.getBlocks().splice(before, region.getBlocks());
342 }
inlineRegionBefore(Region & region,Block * before)343 void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) {
344 inlineRegionBefore(region, *before->getParent(), before->getIterator());
345 }
346
347 /// Clone the blocks that belong to "region" before the given position in
348 /// another region "parent". The two regions must be different. The caller is
349 /// responsible for creating or updating the operation transferring flow of
350 /// control to the region and passing it the correct block arguments.
cloneRegionBefore(Region & region,Region & parent,Region::iterator before,BlockAndValueMapping & mapping)351 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent,
352 Region::iterator before,
353 BlockAndValueMapping &mapping) {
354 region.cloneInto(&parent, before, mapping);
355 }
cloneRegionBefore(Region & region,Region & parent,Region::iterator before)356 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent,
357 Region::iterator before) {
358 BlockAndValueMapping mapping;
359 cloneRegionBefore(region, parent, before, mapping);
360 }
cloneRegionBefore(Region & region,Block * before)361 void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) {
362 cloneRegionBefore(region, *before->getParent(), before->getIterator());
363 }
364