1 //===- PatternMatch.cpp - Base classes for pattern match ------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "mlir/IR/PatternMatch.h" 10 #include "mlir/IR/BlockAndValueMapping.h" 11 12 using namespace mlir; 13 14 //===----------------------------------------------------------------------===// 15 // PatternBenefit 16 //===----------------------------------------------------------------------===// 17 18 PatternBenefit::PatternBenefit(unsigned benefit) : representation(benefit) { 19 assert(representation == benefit && benefit != ImpossibleToMatchSentinel && 20 "This pattern match benefit is too large to represent"); 21 } 22 23 unsigned short PatternBenefit::getBenefit() const { 24 assert(!isImpossibleToMatch() && "Pattern doesn't match"); 25 return representation; 26 } 27 28 //===----------------------------------------------------------------------===// 29 // Pattern 30 //===----------------------------------------------------------------------===// 31 32 //===----------------------------------------------------------------------===// 33 // OperationName Root Constructors 34 35 Pattern::Pattern(StringRef rootName, PatternBenefit benefit, 36 MLIRContext *context, ArrayRef<StringRef> generatedNames) 37 : Pattern(OperationName(rootName, context).getAsOpaquePointer(), 38 RootKind::OperationName, generatedNames, benefit, context) {} 39 40 //===----------------------------------------------------------------------===// 41 // MatchAnyOpTypeTag Root Constructors 42 43 Pattern::Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, 44 MLIRContext *context, ArrayRef<StringRef> generatedNames) 45 : Pattern(nullptr, RootKind::Any, generatedNames, benefit, context) {} 46 47 //===----------------------------------------------------------------------===// 48 // MatchInterfaceOpTypeTag Root Constructors 49 50 Pattern::Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, 51 PatternBenefit benefit, MLIRContext *context, 52 ArrayRef<StringRef> generatedNames) 53 : Pattern(interfaceID.getAsOpaquePointer(), RootKind::InterfaceID, 54 generatedNames, benefit, context) {} 55 56 //===----------------------------------------------------------------------===// 57 // MatchTraitOpTypeTag Root Constructors 58 59 Pattern::Pattern(MatchTraitOpTypeTag tag, TypeID traitID, 60 PatternBenefit benefit, MLIRContext *context, 61 ArrayRef<StringRef> generatedNames) 62 : Pattern(traitID.getAsOpaquePointer(), RootKind::TraitID, generatedNames, 63 benefit, context) {} 64 65 //===----------------------------------------------------------------------===// 66 // General Constructors 67 68 Pattern::Pattern(const void *rootValue, RootKind rootKind, 69 ArrayRef<StringRef> generatedNames, PatternBenefit benefit, 70 MLIRContext *context) 71 : rootValue(rootValue), rootKind(rootKind), benefit(benefit), 72 contextAndHasBoundedRecursion(context, false) { 73 if (generatedNames.empty()) 74 return; 75 generatedOps.reserve(generatedNames.size()); 76 std::transform(generatedNames.begin(), generatedNames.end(), 77 std::back_inserter(generatedOps), [context](StringRef name) { 78 return OperationName(name, context); 79 }); 80 } 81 82 //===----------------------------------------------------------------------===// 83 // RewritePattern 84 //===----------------------------------------------------------------------===// 85 86 void RewritePattern::rewrite(Operation *op, PatternRewriter &rewriter) const { 87 llvm_unreachable("need to implement either matchAndRewrite or one of the " 88 "rewrite functions!"); 89 } 90 91 LogicalResult RewritePattern::match(Operation *op) const { 92 llvm_unreachable("need to implement either match or matchAndRewrite!"); 93 } 94 95 /// Out-of-line vtable anchor. 96 void RewritePattern::anchor() {} 97 98 //===----------------------------------------------------------------------===// 99 // PDLValue 100 //===----------------------------------------------------------------------===// 101 102 void PDLValue::print(raw_ostream &os) const { 103 if (!value) { 104 os << "<NULL-PDLValue>"; 105 return; 106 } 107 switch (kind) { 108 case Kind::Attribute: 109 os << cast<Attribute>(); 110 break; 111 case Kind::Operation: 112 os << *cast<Operation *>(); 113 break; 114 case Kind::Type: 115 os << cast<Type>(); 116 break; 117 case Kind::TypeRange: 118 llvm::interleaveComma(cast<TypeRange>(), os); 119 break; 120 case Kind::Value: 121 os << cast<Value>(); 122 break; 123 case Kind::ValueRange: 124 llvm::interleaveComma(cast<ValueRange>(), os); 125 break; 126 } 127 } 128 129 void PDLValue::print(raw_ostream &os, Kind kind) { 130 switch (kind) { 131 case Kind::Attribute: 132 os << "Attribute"; 133 break; 134 case Kind::Operation: 135 os << "Operation"; 136 break; 137 case Kind::Type: 138 os << "Type"; 139 break; 140 case Kind::TypeRange: 141 os << "TypeRange"; 142 break; 143 case Kind::Value: 144 os << "Value"; 145 break; 146 case Kind::ValueRange: 147 os << "ValueRange"; 148 break; 149 } 150 } 151 152 //===----------------------------------------------------------------------===// 153 // PDLPatternModule 154 //===----------------------------------------------------------------------===// 155 156 void PDLPatternModule::mergeIn(PDLPatternModule &&other) { 157 // Ignore the other module if it has no patterns. 158 if (!other.pdlModule) 159 return; 160 161 // Steal the functions of the other module. 162 for (auto &it : other.constraintFunctions) 163 registerConstraintFunction(it.first(), std::move(it.second)); 164 for (auto &it : other.rewriteFunctions) 165 registerRewriteFunction(it.first(), std::move(it.second)); 166 167 // Steal the other state if we have no patterns. 168 if (!pdlModule) { 169 pdlModule = std::move(other.pdlModule); 170 return; 171 } 172 173 // Merge the pattern operations from the other module into this one. 174 Block *block = pdlModule->getBody(); 175 block->getOperations().splice(block->end(), 176 other.pdlModule->getBody()->getOperations()); 177 } 178 179 //===----------------------------------------------------------------------===// 180 // Function Registry 181 182 void PDLPatternModule::registerConstraintFunction( 183 StringRef name, PDLConstraintFunction constraintFn) { 184 // TODO: Is it possible to diagnose when `name` is already registered to 185 // a function that is not equivalent to `constraintFn`? 186 // Allow existing mappings in the case multiple patterns depend on the same 187 // constraint. 188 constraintFunctions.try_emplace(name, std::move(constraintFn)); 189 } 190 191 void PDLPatternModule::registerRewriteFunction(StringRef name, 192 PDLRewriteFunction rewriteFn) { 193 // TODO: Is it possible to diagnose when `name` is already registered to 194 // a function that is not equivalent to `rewriteFn`? 195 // Allow existing mappings in the case multiple patterns depend on the same 196 // rewrite. 197 rewriteFunctions.try_emplace(name, std::move(rewriteFn)); 198 } 199 200 //===----------------------------------------------------------------------===// 201 // RewriterBase 202 //===----------------------------------------------------------------------===// 203 204 RewriterBase::~RewriterBase() { 205 // Out of line to provide a vtable anchor for the class. 206 } 207 208 /// This method replaces the uses of the results of `op` with the values in 209 /// `newValues` when the provided `functor` returns true for a specific use. 210 /// The number of values in `newValues` is required to match the number of 211 /// results of `op`. 212 void RewriterBase::replaceOpWithIf( 213 Operation *op, ValueRange newValues, bool *allUsesReplaced, 214 llvm::unique_function<bool(OpOperand &) const> functor) { 215 assert(op->getNumResults() == newValues.size() && 216 "incorrect number of values to replace operation"); 217 218 // Notify the rewriter subclass that we're about to replace this root. 219 notifyRootReplaced(op); 220 221 // Replace each use of the results when the functor is true. 222 bool replacedAllUses = true; 223 for (auto it : llvm::zip(op->getResults(), newValues)) { 224 std::get<0>(it).replaceUsesWithIf(std::get<1>(it), functor); 225 replacedAllUses &= std::get<0>(it).use_empty(); 226 } 227 if (allUsesReplaced) 228 *allUsesReplaced = replacedAllUses; 229 } 230 231 /// This method replaces the uses of the results of `op` with the values in 232 /// `newValues` when a use is nested within the given `block`. The number of 233 /// values in `newValues` is required to match the number of results of `op`. 234 /// If all uses of this operation are replaced, the operation is erased. 235 void RewriterBase::replaceOpWithinBlock(Operation *op, ValueRange newValues, 236 Block *block, bool *allUsesReplaced) { 237 replaceOpWithIf(op, newValues, allUsesReplaced, [block](OpOperand &use) { 238 return block->getParentOp()->isProperAncestor(use.getOwner()); 239 }); 240 } 241 242 /// This method replaces the results of the operation with the specified list of 243 /// values. The number of provided values must match the number of results of 244 /// the operation. 245 void RewriterBase::replaceOp(Operation *op, ValueRange newValues) { 246 // Notify the rewriter subclass that we're about to replace this root. 247 notifyRootReplaced(op); 248 249 assert(op->getNumResults() == newValues.size() && 250 "incorrect # of replacement values"); 251 op->replaceAllUsesWith(newValues); 252 253 notifyOperationRemoved(op); 254 op->erase(); 255 } 256 257 /// This method erases an operation that is known to have no uses. The uses of 258 /// the given operation *must* be known to be dead. 259 void RewriterBase::eraseOp(Operation *op) { 260 assert(op->use_empty() && "expected 'op' to have no uses"); 261 notifyOperationRemoved(op); 262 op->erase(); 263 } 264 265 void RewriterBase::eraseBlock(Block *block) { 266 for (auto &op : llvm::make_early_inc_range(llvm::reverse(*block))) { 267 assert(op.use_empty() && "expected 'op' to have no uses"); 268 eraseOp(&op); 269 } 270 block->erase(); 271 } 272 273 /// Merge the operations of block 'source' into the end of block 'dest'. 274 /// 'source's predecessors must be empty or only contain 'dest`. 275 /// 'argValues' is used to replace the block arguments of 'source' after 276 /// merging. 277 void RewriterBase::mergeBlocks(Block *source, Block *dest, 278 ValueRange argValues) { 279 assert(llvm::all_of(source->getPredecessors(), 280 [dest](Block *succ) { return succ == dest; }) && 281 "expected 'source' to have no predecessors or only 'dest'"); 282 assert(argValues.size() == source->getNumArguments() && 283 "incorrect # of argument replacement values"); 284 285 // Replace all of the successor arguments with the provided values. 286 for (auto it : llvm::zip(source->getArguments(), argValues)) 287 std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); 288 289 // Splice the operations of the 'source' block into the 'dest' block and erase 290 // it. 291 dest->getOperations().splice(dest->end(), source->getOperations()); 292 source->dropAllUses(); 293 source->erase(); 294 } 295 296 // Merge the operations of block 'source' before the operation 'op'. Source 297 // block should not have existing predecessors or successors. 298 void RewriterBase::mergeBlockBefore(Block *source, Operation *op, 299 ValueRange argValues) { 300 assert(source->hasNoPredecessors() && 301 "expected 'source' to have no predecessors"); 302 assert(source->hasNoSuccessors() && 303 "expected 'source' to have no successors"); 304 305 // Split the block containing 'op' into two, one containing all operations 306 // before 'op' (prologue) and another (epilogue) containing 'op' and all 307 // operations after it. 308 Block *prologue = op->getBlock(); 309 Block *epilogue = splitBlock(prologue, op->getIterator()); 310 311 // Merge the source block at the end of the prologue. 312 mergeBlocks(source, prologue, argValues); 313 314 // Merge the epilogue at the end the prologue. 315 mergeBlocks(epilogue, prologue); 316 } 317 318 /// Split the operations starting at "before" (inclusive) out of the given 319 /// block into a new block, and return it. 320 Block *RewriterBase::splitBlock(Block *block, Block::iterator before) { 321 return block->splitBlock(before); 322 } 323 324 /// 'op' and 'newOp' are known to have the same number of results, replace the 325 /// uses of op with uses of newOp 326 void RewriterBase::replaceOpWithResultsOfAnotherOp(Operation *op, 327 Operation *newOp) { 328 assert(op->getNumResults() == newOp->getNumResults() && 329 "replacement op doesn't match results of original op"); 330 if (op->getNumResults() == 1) 331 return replaceOp(op, newOp->getResult(0)); 332 return replaceOp(op, newOp->getResults()); 333 } 334 335 /// Move the blocks that belong to "region" before the given position in 336 /// another region. The two regions must be different. The caller is in 337 /// charge to update create the operation transferring the control flow to the 338 /// region and pass it the correct block arguments. 339 void RewriterBase::inlineRegionBefore(Region ®ion, Region &parent, 340 Region::iterator before) { 341 parent.getBlocks().splice(before, region.getBlocks()); 342 } 343 void RewriterBase::inlineRegionBefore(Region ®ion, Block *before) { 344 inlineRegionBefore(region, *before->getParent(), before->getIterator()); 345 } 346 347 /// Clone the blocks that belong to "region" before the given position in 348 /// another region "parent". The two regions must be different. The caller is 349 /// responsible for creating or updating the operation transferring flow of 350 /// control to the region and passing it the correct block arguments. 351 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, 352 Region::iterator before, 353 BlockAndValueMapping &mapping) { 354 region.cloneInto(&parent, before, mapping); 355 } 356 void RewriterBase::cloneRegionBefore(Region ®ion, Region &parent, 357 Region::iterator before) { 358 BlockAndValueMapping mapping; 359 cloneRegionBefore(region, parent, before, mapping); 360 } 361 void RewriterBase::cloneRegionBefore(Region ®ion, Block *before) { 362 cloneRegionBefore(region, *before->getParent(), before->getIterator()); 363 } 364