1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/PatternMatch.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 12 using namespace mlir; 13 14 //===----------------------------------------------------------------------===// 15 // PatternBenefit 16 //===----------------------------------------------------------------------===// 17 18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { 19 assert(representation == benefit && benefit != ImpossibleToMatchSentinel && 20 "This pattern match benefit is too large to represent"); 21 } 22 23 unsigned short PatternBenefit::getBenefit() const { 24 assert(!isImpossibleToMatch() && "Pattern doesn't match"); 25 return representation; 26 } 27 28 //===----------------------------------------------------------------------===// 29 // Pattern 30 //===----------------------------------------------------------------------===// 31 32 //===----------------------------------------------------------------------===// 33 // OperationName Root Constructors 34 35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit, 36 MLIRContext *context, ArrayRef<StringRef> generatedNames) 37 : Pattern(OperationName(rootName, context).getAsOpaquePointer(), 38 RootKind::OperationName, generatedNames, benefit, context) {} 39 40 //===----------------------------------------------------------------------===// 41 // MatchAnyOpTypeTag Root Constructors 42 43 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, 44 MLIRContext *context, ArrayRef<StringRef> generatedNames) 45 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {} 46 47 //===----------------------------------------------------------------------===// 48 // MatchInterfaceOpTypeTag Root Constructors 49 50 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, 51 PatternBenefit benefit, MLIRContext *context, 52 ArrayRef<StringRef> generatedNames) 53 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID, 54 generatedNames, benefit, context) {} 55 56 //===----------------------------------------------------------------------===// 57 // MatchTraitOpTypeTag Root Constructors 58 59 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID, 60 PatternBenefit benefit, MLIRContext *context, 61 ArrayRef<StringRef> generatedNames) 62 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames, 63 benefit, context) {} 64 65 //===----------------------------------------------------------------------===// 66 // General Constructors 67 68 Pattern::Pattern(const void *rootValue, RootKind rootKind, 69 ArrayRef<StringRef> generatedNames, PatternBenefit benefit, 70 MLIRContext *context) 71 : rootValue(rootValue), rootKind(rootKind), benefit(benefit), 72 contextAndHasBoundedRecursion(context, false) { 73 if (generatedNames.empty()) 74 return; 75 generatedOps.reserve(generatedNames.size()); 76 std::transform(generatedNames.begin(), generatedNames.end(), 77 std::back_inserter(generatedOps), [context](StringRef name) { 78 return OperationName(name, context); 79 }); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // RewritePattern 84 //===----------------------------------------------------------------------===// 85 86 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { 87 llvm_unreachable("need to implement either matchAndRewrite or one of the " 88 "rewrite functions!"); 89 } 90 91 LogicalResult RewritePattern::match(Operation *op) const { 92 llvm_unreachable("need to implement either match or matchAndRewrite!"); 93 } 94 95 /// Out-of-line vtable anchor. 96 void RewritePattern::anchor() {} 97 98 //===----------------------------------------------------------------------===// 99 // PDLValue 100 //===----------------------------------------------------------------------===// 101 102 void PDLValue::print(raw_ostream &os) const { 103 if (!value) { 104 os << "<NULL-PDLValue>"; 105 return; 106 } 107 switch (kind) { 108 case Kind::Attribute: 109 os << cast<Attribute>(); 110 break; 111 case Kind::Operation: 112 os << *cast<Operation *>(); 113 break; 114 case Kind::Type: 115 os << cast<Type>(); 116 break; 117 case Kind::TypeRange: 118 llvm::interleaveComma(cast<TypeRange>(), os); 119 break; 120 case Kind::Value: 121 os << cast<Value>(); 122 break; 123 case Kind::ValueRange: 124 llvm::interleaveComma(cast<ValueRange>(), os); 125 break; 126 } 127 } 128 129 //===----------------------------------------------------------------------===// 130 // PDLPatternModule 131 //===----------------------------------------------------------------------===// 132 133 void PDLPatternModule::mergeIn(PDLPatternModule &&other) { 134 // Ignore the other module if it has no patterns. 135 if (!other.pdlModule) 136 return; 137 // Steal the other state if we have no patterns. 138 if (!pdlModule) { 139 constraintFunctions = std::move(other.constraintFunctions); 140 rewriteFunctions = std::move(other.rewriteFunctions); 141 pdlModule = std::move(other.pdlModule); 142 return; 143 } 144 // Steal the functions of the other module. 145 for (auto &it : constraintFunctions) 146 registerConstraintFunction(it.first(), std::move(it.second)); 147 for (auto &it : rewriteFunctions) 148 registerRewriteFunction(it.first(), std::move(it.second)); 149 150 // Merge the pattern operations from the other module into this one. 151 Block *block = pdlModule->getBody(); 152 block->getTerminator()->erase(); 153 block->getOperations().splice(block->end(), 154 other.pdlModule->getBody()->getOperations()); 155 } 156 157 //===----------------------------------------------------------------------===// 158 // Function Registry 159 160 void PDLPatternModule::registerConstraintFunction( 161 StringRef name, PDLConstraintFunction constraintFn) { 162 auto it = constraintFunctions.try_emplace(name, std::move(constraintFn)); 163 (void)it; 164 assert(it.second && 165 "constraint with the given name has already been registered"); 166 } 167 168 void PDLPatternModule::registerRewriteFunction(StringRef name, 169 PDLRewriteFunction rewriteFn) { 170 auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn)); 171 (void)it; 172 assert(it.second && "native rewrite function with the given name has " 173 "already been registered"); 174 } 175 176 //===----------------------------------------------------------------------===// 177 // RewriterBase 178 //===----------------------------------------------------------------------===// 179 180 RewriterBase::~RewriterBase() { 181 // Out of line to provide a vtable anchor for the class. 182 } 183 184 /// This method replaces the uses of the results of `op` with the values in 185 /// `newValues` when the provided `functor` returns true for a specific use. 186 /// The number of values in `newValues` is required to match the number of 187 /// results of `op`. 188 void RewriterBase::replaceOpWithIf( 189 Operation *op, ValueRange newValues, bool *allUsesReplaced, 190 llvm::unique_function<bool(OpOperand &) const> functor) { 191 assert(op->getNumResults() == newValues.size() && 192 "incorrect number of values to replace operation"); 193 194 // Notify the rewriter subclass that we're about to replace this root. 195 notifyRootReplaced(op); 196 197 // Replace each use of the results when the functor is true. 198 bool replacedAllUses = true; 199 for (auto it : llvm::zip(op->getResults(), newValues)) { 200 std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor); 201 replacedAllUses &= std::get<0>(it).use_empty(); 202 } 203 if (allUsesReplaced) 204 *allUsesReplaced = replacedAllUses; 205 } 206 207 /// This method replaces the uses of the results of `op` with the values in 208 /// `newValues` when a use is nested within the given `block`. The number of 209 /// values in `newValues` is required to match the number of results of `op`. 210 /// If all uses of this operation are replaced, the operation is erased. 211 void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues, 212 Block *block, bool *allUsesReplaced) { 213 replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) { 214 return block->getParentOp()->isProperAncestor(use.getOwner()); 215 }); 216 } 217 218 /// This method replaces the results of the operation with the specified list of 219 /// values. The number of provided values must match the number of results of 220 /// the operation. 221 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { 222 // Notify the rewriter subclass that we're about to replace this root. 223 notifyRootReplaced(op); 224 225 assert(op->getNumResults() == newValues.size() && 226 "incorrect # of replacement values"); 227 op->replaceAllUsesWith(newValues); 228 229 notifyOperationRemoved(op); 230 op->erase(); 231 } 232 233 /// This method erases an operation that is known to have no uses. The uses of 234 /// the given operation *must* be known to be dead. 235 void RewriterBase::eraseOp(Operation *op) { 236 assert(op->use_empty() && "expected 'op' to have no uses"); 237 notifyOperationRemoved(op); 238 op->erase(); 239 } 240 241 void RewriterBase::eraseBlock(Block *block) { 242 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { 243 assert(op.use_empty() && "expected 'op' to have no uses"); 244 eraseOp(&op); 245 } 246 block->erase(); 247 } 248 249 /// Merge the operations of block 'source' into the end of block 'dest'. 250 /// 'source's predecessors must be empty or only contain 'dest`. 251 /// 'argValues' is used to replace the block arguments of 'source' after 252 /// merging. 253 void RewriterBase::mergeBlocks(Block *source, Block *dest, 254 ValueRange argValues) { 255 assert(llvm::all_of(source->getPredecessors(), 256 [dest](Block *succ) { return succ == dest; }) && 257 "expected 'source' to have no predecessors or only 'dest'"); 258 assert(argValues.size() == source->getNumArguments() && 259 "incorrect # of argument replacement values"); 260 261 // Replace all of the successor arguments with the provided values. 262 for (auto it : llvm::zip(source->getArguments(), argValues)) 263 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 264 265 // Splice the operations of the 'source' block into the 'dest' block and erase 266 // it. 267 dest->getOperations().splice(dest->end(), source->getOperations()); 268 source->dropAllUses(); 269 source->erase(); 270 } 271 272 // Merge the operations of block 'source' before the operation 'op'. Source 273 // block should not have existing predecessors or successors. 274 void RewriterBase::mergeBlockBefore(Block *source, Operation *op, 275 ValueRange argValues) { 276 assert(source->hasNoPredecessors() && 277 "expected 'source' to have no predecessors"); 278 assert(source->hasNoSuccessors() && 279 "expected 'source' to have no successors"); 280 281 // Split the block containing 'op' into two, one containing all operations 282 // before 'op' (prologue) and another (epilogue) containing 'op' and all 283 // operations after it. 284 Block *prologue = op->getBlock(); 285 Block *epilogue = splitBlock(prologue, op->getIterator()); 286 287 // Merge the source block at the end of the prologue. 288 mergeBlocks(source, prologue, argValues); 289 290 // Merge the epilogue at the end the prologue. 291 mergeBlocks(epilogue, prologue); 292 } 293 294 /// Split the operations starting at "before" (inclusive) out of the given 295 /// block into a new block, and return it. 296 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) { 297 return block->splitBlock(before); 298 } 299 300 /// 'op' and 'newOp' are known to have the same number of results, replace the 301 /// uses of op with uses of newOp 302 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op, 303 Operation *newOp) { 304 assert(op->getNumResults() == newOp->getNumResults() && 305 "replacement op doesn't match results of original op"); 306 if (op->getNumResults() == 1) 307 return replaceOp(op, newOp->getResult(0)); 308 return replaceOp(op, newOp->getResults()); 309 } 310 311 /// Move the blocks that belong to "region" before the given position in 312 /// another region. The two regions must be different. The caller is in 313 /// charge to update create the operation transferring the control flow to the 314 /// region and pass it the correct block arguments. 315 void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent, 316 Region::iterator before) { 317 parent.getBlocks().splice(before, region.getBlocks()); 318 } 319 void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) { 320 inlineRegionBefore(region, *before->getParent(), before->getIterator()); 321 } 322 323 /// Clone the blocks that belong to "region" before the given position in 324 /// another region "parent". The two regions must be different. The caller is 325 /// responsible for creating or updating the operation transferring flow of 326 /// control to the region and passing it the correct block arguments. 327 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, 328 Region::iterator before, 329 BlockAndValueMapping &mapping) { 330 region.cloneInto(&parent, before, mapping); 331 } 332 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, 333 Region::iterator before) { 334 BlockAndValueMapping mapping; 335 cloneRegionBefore(region, parent, before, mapping); 336 } 337 void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) { 338 cloneRegionBefore(region, *before->getParent(), before->getIterator()); 339 } 340