1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/PatternMatch.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 #include "mlir/IR/Operation.h" 12 #include "mlir/IR/Value.h" 13 14 using namespace mlir; 15 16 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { 17 assert(representation == benefit && benefit != ImpossibleToMatchSentinel && 18 "This pattern match benefit is too large to represent"); 19 } 20 21 unsigned short PatternBenefit::getBenefit() const { 22 assert(representation != ImpossibleToMatchSentinel && 23 "Pattern doesn't match"); 24 return representation; 25 } 26 27 //===----------------------------------------------------------------------===// 28 // Pattern implementation 29 //===----------------------------------------------------------------------===// 30 31 Pattern::Pattern(StringRef rootName, PatternBenefit benefit, 32 MLIRContext *context) 33 : rootKind(OperationName(rootName, context)), benefit(benefit) {} 34 35 // Out-of-line vtable anchor. 36 void Pattern::anchor() {} 37 38 //===----------------------------------------------------------------------===// 39 // RewritePattern and PatternRewriter implementation 40 //===----------------------------------------------------------------------===// 41 42 void RewritePattern::rewrite(Operation *op, std::unique_ptr<PatternState> state, 43 PatternRewriter &rewriter) const { 44 rewrite(op, rewriter); 45 } 46 47 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { 48 llvm_unreachable("need to implement either matchAndRewrite or one of the " 49 "rewrite functions!"); 50 } 51 52 PatternMatchResult RewritePattern::match(Operation *op) const { 53 llvm_unreachable("need to implement either match or matchAndRewrite!"); 54 } 55 56 /// Patterns must specify the root operation name they match against, and can 57 /// also specify the benefit of the pattern matching. They can also specify the 58 /// names of operations that may be generated during a successful rewrite. 59 RewritePattern::RewritePattern(StringRef rootName, 60 ArrayRef<StringRef> generatedNames, 61 PatternBenefit benefit, MLIRContext *context) 62 : Pattern(rootName, benefit, context) { 63 generatedOps.reserve(generatedNames.size()); 64 std::transform(generatedNames.begin(), generatedNames.end(), 65 std::back_inserter(generatedOps), [context](StringRef name) { 66 return OperationName(name, context); 67 }); 68 } 69 70 PatternRewriter::~PatternRewriter() { 71 // Out of line to provide a vtable anchor for the class. 72 } 73 74 /// This method performs the final replacement for a pattern, where the 75 /// results of the operation are updated to use the specified list of SSA 76 /// values. 77 void PatternRewriter::replaceOp(Operation *op, ValueRange newValues) { 78 // Notify the rewriter subclass that we're about to replace this root. 79 notifyRootReplaced(op); 80 81 assert(op->getNumResults() == newValues.size() && 82 "incorrect # of replacement values"); 83 op->replaceAllUsesWith(newValues); 84 85 notifyOperationRemoved(op); 86 op->erase(); 87 } 88 89 /// This method erases an operation that is known to have no uses. The uses of 90 /// the given operation *must* be known to be dead. 91 void PatternRewriter::eraseOp(Operation *op) { 92 assert(op->use_empty() && "expected 'op' to have no uses"); 93 notifyOperationRemoved(op); 94 op->erase(); 95 } 96 97 /// Merge the operations of block 'source' into the end of block 'dest'. 98 /// 'source's predecessors must be empty or only contain 'dest`. 99 /// 'argValues' is used to replace the block arguments of 'source' after 100 /// merging. 101 void PatternRewriter::mergeBlocks(Block *source, Block *dest, 102 ValueRange argValues) { 103 assert(llvm::all_of(source->getPredecessors(), 104 [dest](Block *succ) { return succ == dest; }) && 105 "expected 'source' to have no predecessors or only 'dest'"); 106 assert(argValues.size() == source->getNumArguments() && 107 "incorrect # of argument replacement values"); 108 109 // Replace all of the successor arguments with the provided values. 110 for (auto it : llvm::zip(source->getArguments(), argValues)) 111 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 112 113 // Splice the operations of the 'source' block into the 'dest' block and erase 114 // it. 115 dest->getOperations().splice(dest->end(), source->getOperations()); 116 source->dropAllUses(); 117 source->erase(); 118 } 119 120 /// Split the operations starting at "before" (inclusive) out of the given 121 /// block into a new block, and return it. 122 Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) { 123 return block->splitBlock(before); 124 } 125 126 /// 'op' and 'newOp' are known to have the same number of results, replace the 127 /// uses of op with uses of newOp 128 void PatternRewriter::replaceOpWithResultsOfAnotherOp(Operation *op, 129 Operation *newOp) { 130 assert(op->getNumResults() == newOp->getNumResults() && 131 "replacement op doesn't match results of original op"); 132 if (op->getNumResults() == 1) 133 return replaceOp(op, newOp->getResult(0)); 134 return replaceOp(op, newOp->getResults()); 135 } 136 137 /// Move the blocks that belong to "region" before the given position in 138 /// another region. The two regions must be different. The caller is in 139 /// charge to update create the operation transferring the control flow to the 140 /// region and pass it the correct block arguments. 141 void PatternRewriter::inlineRegionBefore(Region ®ion, Region &parent, 142 Region::iterator before) { 143 parent.getBlocks().splice(before, region.getBlocks()); 144 } 145 void PatternRewriter::inlineRegionBefore(Region ®ion, Block *before) { 146 inlineRegionBefore(region, *before->getParent(), before->getIterator()); 147 } 148 149 /// Clone the blocks that belong to "region" before the given position in 150 /// another region "parent". The two regions must be different. The caller is 151 /// responsible for creating or updating the operation transferring flow of 152 /// control to the region and passing it the correct block arguments. 153 void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent, 154 Region::iterator before, 155 BlockAndValueMapping &mapping) { 156 region.cloneInto(&parent, before, mapping); 157 } 158 void PatternRewriter::cloneRegionBefore(Region ®ion, Region &parent, 159 Region::iterator before) { 160 BlockAndValueMapping mapping; 161 cloneRegionBefore(region, parent, before, mapping); 162 } 163 void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) { 164 cloneRegionBefore(region, *before->getParent(), before->getIterator()); 165 } 166 167 //===----------------------------------------------------------------------===// 168 // PatternMatcher implementation 169 //===----------------------------------------------------------------------===// 170 171 RewritePatternMatcher::RewritePatternMatcher( 172 const OwningRewritePatternList &patterns) { 173 for (auto &pattern : patterns) 174 this->patterns.push_back(pattern.get()); 175 176 // Sort the patterns by benefit to simplify the matching logic. 177 std::stable_sort(this->patterns.begin(), this->patterns.end(), 178 [](RewritePattern *l, RewritePattern *r) { 179 return r->getBenefit() < l->getBenefit(); 180 }); 181 } 182 183 /// Try to match the given operation to a pattern and rewrite it. 184 bool RewritePatternMatcher::matchAndRewrite(Operation *op, 185 PatternRewriter &rewriter) { 186 for (auto *pattern : patterns) { 187 // Ignore patterns that are for the wrong root or are impossible to match. 188 if (pattern->getRootKind() != op->getName() || 189 pattern->getBenefit().isImpossibleToMatch()) 190 continue; 191 192 // Try to match and rewrite this pattern. The patterns are sorted by 193 // benefit, so if we match we can immediately rewrite and return. 194 if (pattern->matchAndRewrite(op, rewriter)) 195 return true; 196 } 197 return false; 198 } 199