1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/PatternMatch.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 12 using namespace mlir; 13 14 //===----------------------------------------------------------------------===// 15 // PatternBenefit 16 //===----------------------------------------------------------------------===// 17 18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { 19 assert(representation == benefit && benefit != ImpossibleToMatchSentinel && 20 "This pattern match benefit is too large to represent"); 21 } 22 23 unsigned short PatternBenefit::getBenefit() const { 24 assert(!isImpossibleToMatch() && "Pattern doesn't match"); 25 return representation; 26 } 27 28 //===----------------------------------------------------------------------===// 29 // Pattern 30 //===----------------------------------------------------------------------===// 31 32 //===----------------------------------------------------------------------===// 33 // OperationName Root Constructors 34 35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit, 36 MLIRContext *context, ArrayRef<StringRef> generatedNames) 37 : Pattern(OperationName(rootName, context).getAsOpaquePointer(), 38 RootKind::OperationName, generatedNames, benefit, context) {} 39 40 //===----------------------------------------------------------------------===// 41 // MatchAnyOpTypeTag Root Constructors 42 43 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, 44 MLIRContext *context, ArrayRef<StringRef> generatedNames) 45 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {} 46 47 //===----------------------------------------------------------------------===// 48 // MatchInterfaceOpTypeTag Root Constructors 49 50 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, 51 PatternBenefit benefit, MLIRContext *context, 52 ArrayRef<StringRef> generatedNames) 53 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID, 54 generatedNames, benefit, context) {} 55 56 //===----------------------------------------------------------------------===// 57 // MatchTraitOpTypeTag Root Constructors 58 59 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID, 60 PatternBenefit benefit, MLIRContext *context, 61 ArrayRef<StringRef> generatedNames) 62 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames, 63 benefit, context) {} 64 65 //===----------------------------------------------------------------------===// 66 // General Constructors 67 68 Pattern::Pattern(const void *rootValue, RootKind rootKind, 69 ArrayRef<StringRef> generatedNames, PatternBenefit benefit, 70 MLIRContext *context) 71 : rootValue(rootValue), rootKind(rootKind), benefit(benefit), 72 contextAndHasBoundedRecursion(context, false) { 73 if (generatedNames.empty()) 74 return; 75 generatedOps.reserve(generatedNames.size()); 76 std::transform(generatedNames.begin(), generatedNames.end(), 77 std::back_inserter(generatedOps), [context](StringRef name) { 78 return OperationName(name, context); 79 }); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // RewritePattern 84 //===----------------------------------------------------------------------===// 85 86 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { 87 llvm_unreachable("need to implement either matchAndRewrite or one of the " 88 "rewrite functions!"); 89 } 90 91 LogicalResult RewritePattern::match(Operation *op) const { 92 llvm_unreachable("need to implement either match or matchAndRewrite!"); 93 } 94 95 /// Out-of-line vtable anchor. 96 void RewritePattern::anchor() {} 97 98 //===----------------------------------------------------------------------===// 99 // PDLValue 100 //===----------------------------------------------------------------------===// 101 102 void PDLValue::print(raw_ostream &os) const { 103 if (!value) { 104 os << "<NULL-PDLValue>"; 105 return; 106 } 107 switch (kind) { 108 case Kind::Attribute: 109 os << cast<Attribute>(); 110 break; 111 case Kind::Operation: 112 os << *cast<Operation *>(); 113 break; 114 case Kind::Type: 115 os << cast<Type>(); 116 break; 117 case Kind::TypeRange: 118 llvm::interleaveComma(cast<TypeRange>(), os); 119 break; 120 case Kind::Value: 121 os << cast<Value>(); 122 break; 123 case Kind::ValueRange: 124 llvm::interleaveComma(cast<ValueRange>(), os); 125 break; 126 } 127 } 128 129 void PDLValue::print(raw_ostream &os, Kind kind) { 130 switch (kind) { 131 case Kind::Attribute: 132 os << "Attribute"; 133 break; 134 case Kind::Operation: 135 os << "Operation"; 136 break; 137 case Kind::Type: 138 os << "Type"; 139 break; 140 case Kind::TypeRange: 141 os << "TypeRange"; 142 break; 143 case Kind::Value: 144 os << "Value"; 145 break; 146 case Kind::ValueRange: 147 os << "ValueRange"; 148 break; 149 } 150 } 151 152 //===----------------------------------------------------------------------===// 153 // PDLPatternModule 154 //===----------------------------------------------------------------------===// 155 156 void PDLPatternModule::mergeIn(PDLPatternModule &&other) { 157 // Ignore the other module if it has no patterns. 158 if (!other.pdlModule) 159 return; 160 // Steal the other state if we have no patterns. 161 if (!pdlModule) { 162 constraintFunctions = std::move(other.constraintFunctions); 163 rewriteFunctions = std::move(other.rewriteFunctions); 164 pdlModule = std::move(other.pdlModule); 165 return; 166 } 167 // Steal the functions of the other module. 168 for (auto &it : constraintFunctions) 169 registerConstraintFunction(it.first(), std::move(it.second)); 170 for (auto &it : rewriteFunctions) 171 registerRewriteFunction(it.first(), std::move(it.second)); 172 173 // Merge the pattern operations from the other module into this one. 174 Block *block = pdlModule->getBody(); 175 block->getTerminator()->erase(); 176 block->getOperations().splice(block->end(), 177 other.pdlModule->getBody()->getOperations()); 178 } 179 180 //===----------------------------------------------------------------------===// 181 // Function Registry 182 183 void PDLPatternModule::registerConstraintFunction( 184 StringRef name, PDLConstraintFunction constraintFn) { 185 auto it = constraintFunctions.try_emplace(name, std::move(constraintFn)); 186 (void)it; 187 assert(it.second && 188 "constraint with the given name has already been registered"); 189 } 190 191 void PDLPatternModule::registerRewriteFunction(StringRef name, 192 PDLRewriteFunction rewriteFn) { 193 auto it = rewriteFunctions.try_emplace(name, std::move(rewriteFn)); 194 (void)it; 195 assert(it.second && "native rewrite function with the given name has " 196 "already been registered"); 197 } 198 199 //===----------------------------------------------------------------------===// 200 // RewriterBase 201 //===----------------------------------------------------------------------===// 202 203 RewriterBase::~RewriterBase() { 204 // Out of line to provide a vtable anchor for the class. 205 } 206 207 /// This method replaces the uses of the results of `op` with the values in 208 /// `newValues` when the provided `functor` returns true for a specific use. 209 /// The number of values in `newValues` is required to match the number of 210 /// results of `op`. 211 void RewriterBase::replaceOpWithIf( 212 Operation *op, ValueRange newValues, bool *allUsesReplaced, 213 llvm::unique_function<bool(OpOperand &) const> functor) { 214 assert(op->getNumResults() == newValues.size() && 215 "incorrect number of values to replace operation"); 216 217 // Notify the rewriter subclass that we're about to replace this root. 218 notifyRootReplaced(op); 219 220 // Replace each use of the results when the functor is true. 221 bool replacedAllUses = true; 222 for (auto it : llvm::zip(op->getResults(), newValues)) { 223 std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor); 224 replacedAllUses &= std::get<0>(it).use_empty(); 225 } 226 if (allUsesReplaced) 227 *allUsesReplaced = replacedAllUses; 228 } 229 230 /// This method replaces the uses of the results of `op` with the values in 231 /// `newValues` when a use is nested within the given `block`. The number of 232 /// values in `newValues` is required to match the number of results of `op`. 233 /// If all uses of this operation are replaced, the operation is erased. 234 void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues, 235 Block *block, bool *allUsesReplaced) { 236 replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) { 237 return block->getParentOp()->isProperAncestor(use.getOwner()); 238 }); 239 } 240 241 /// This method replaces the results of the operation with the specified list of 242 /// values. The number of provided values must match the number of results of 243 /// the operation. 244 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { 245 // Notify the rewriter subclass that we're about to replace this root. 246 notifyRootReplaced(op); 247 248 assert(op->getNumResults() == newValues.size() && 249 "incorrect # of replacement values"); 250 op->replaceAllUsesWith(newValues); 251 252 notifyOperationRemoved(op); 253 op->erase(); 254 } 255 256 /// This method erases an operation that is known to have no uses. The uses of 257 /// the given operation *must* be known to be dead. 258 void RewriterBase::eraseOp(Operation *op) { 259 assert(op->use_empty() && "expected 'op' to have no uses"); 260 notifyOperationRemoved(op); 261 op->erase(); 262 } 263 264 void RewriterBase::eraseBlock(Block *block) { 265 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { 266 assert(op.use_empty() && "expected 'op' to have no uses"); 267 eraseOp(&op); 268 } 269 block->erase(); 270 } 271 272 /// Merge the operations of block 'source' into the end of block 'dest'. 273 /// 'source's predecessors must be empty or only contain 'dest`. 274 /// 'argValues' is used to replace the block arguments of 'source' after 275 /// merging. 276 void RewriterBase::mergeBlocks(Block *source, Block *dest, 277 ValueRange argValues) { 278 assert(llvm::all_of(source->getPredecessors(), 279 [dest](Block *succ) { return succ == dest; }) && 280 "expected 'source' to have no predecessors or only 'dest'"); 281 assert(argValues.size() == source->getNumArguments() && 282 "incorrect # of argument replacement values"); 283 284 // Replace all of the successor arguments with the provided values. 285 for (auto it : llvm::zip(source->getArguments(), argValues)) 286 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 287 288 // Splice the operations of the 'source' block into the 'dest' block and erase 289 // it. 290 dest->getOperations().splice(dest->end(), source->getOperations()); 291 source->dropAllUses(); 292 source->erase(); 293 } 294 295 // Merge the operations of block 'source' before the operation 'op'. Source 296 // block should not have existing predecessors or successors. 297 void RewriterBase::mergeBlockBefore(Block *source, Operation *op, 298 ValueRange argValues) { 299 assert(source->hasNoPredecessors() && 300 "expected 'source' to have no predecessors"); 301 assert(source->hasNoSuccessors() && 302 "expected 'source' to have no successors"); 303 304 // Split the block containing 'op' into two, one containing all operations 305 // before 'op' (prologue) and another (epilogue) containing 'op' and all 306 // operations after it. 307 Block *prologue = op->getBlock(); 308 Block *epilogue = splitBlock(prologue, op->getIterator()); 309 310 // Merge the source block at the end of the prologue. 311 mergeBlocks(source, prologue, argValues); 312 313 // Merge the epilogue at the end the prologue. 314 mergeBlocks(epilogue, prologue); 315 } 316 317 /// Split the operations starting at "before" (inclusive) out of the given 318 /// block into a new block, and return it. 319 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) { 320 return block->splitBlock(before); 321 } 322 323 /// 'op' and 'newOp' are known to have the same number of results, replace the 324 /// uses of op with uses of newOp 325 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op, 326 Operation *newOp) { 327 assert(op->getNumResults() == newOp->getNumResults() && 328 "replacement op doesn't match results of original op"); 329 if (op->getNumResults() == 1) 330 return replaceOp(op, newOp->getResult(0)); 331 return replaceOp(op, newOp->getResults()); 332 } 333 334 /// Move the blocks that belong to "region" before the given position in 335 /// another region. The two regions must be different. The caller is in 336 /// charge to update create the operation transferring the control flow to the 337 /// region and pass it the correct block arguments. 338 void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent, 339 Region::iterator before) { 340 parent.getBlocks().splice(before, region.getBlocks()); 341 } 342 void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) { 343 inlineRegionBefore(region, *before->getParent(), before->getIterator()); 344 } 345 346 /// Clone the blocks that belong to "region" before the given position in 347 /// another region "parent". The two regions must be different. The caller is 348 /// responsible for creating or updating the operation transferring flow of 349 /// control to the region and passing it the correct block arguments. 350 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, 351 Region::iterator before, 352 BlockAndValueMapping &mapping) { 353 region.cloneInto(&parent, before, mapping); 354 } 355 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, 356 Region::iterator before) { 357 BlockAndValueMapping mapping; 358 cloneRegionBefore(region, parent, before, mapping); 359 } 360 void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) { 361 cloneRegionBefore(region, *before->getParent(), before->getIterator()); 362 } 363