1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef MLIR_IR_PATTERNMATCH_H 10 #define MLIR_IR_PATTERNMATCH_H 11 12 #include "mlir/IR/Builders.h" 13 #include "mlir/IR/BuiltinOps.h" 14 #include "llvm/ADT/FunctionExtras.h" 15 #include "llvm/Support/TypeName.h" 16 17 namespace mlir { 18 19 class PatternRewriter; 20 21 //===----------------------------------------------------------------------===// 22 // PatternBenefit class 23 //===----------------------------------------------------------------------===// 24 25 /// This class represents the benefit of a pattern match in a unitless scheme 26 /// that ranges from 0 (very little benefit) to 65K. The most common unit to 27 /// use here is the "number of operations matched" by the pattern. 28 /// 29 /// This also has a sentinel representation that can be used for patterns that 30 /// fail to match. 31 /// 32 class PatternBenefit { 33 enum { ImpossibleToMatchSentinel = 65535 }; 34 35 public: 36 PatternBenefit() = default; 37 PatternBenefit(unsigned benefit); 38 PatternBenefit(const PatternBenefit &) = default; 39 PatternBenefit &operator=(const PatternBenefit &) = default; 40 impossibleToMatch()41 static PatternBenefit impossibleToMatch() { return PatternBenefit(); } isImpossibleToMatch()42 bool isImpossibleToMatch() const { return *this == impossibleToMatch(); } 43 44 /// If the corresponding pattern can match, return its benefit. If the 45 // corresponding pattern isImpossibleToMatch() then this aborts. 46 unsigned short getBenefit() const; 47 48 bool operator==(const PatternBenefit &rhs) const { 49 return representation == rhs.representation; 50 } 51 bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); } 52 bool operator<(const PatternBenefit &rhs) const { 53 return representation < rhs.representation; 54 } 55 bool operator>(const PatternBenefit &rhs) const { return rhs < *this; } 56 bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); } 57 bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); } 58 59 private: 60 unsigned short representation{ImpossibleToMatchSentinel}; 61 }; 62 63 //===----------------------------------------------------------------------===// 64 // Pattern 65 //===----------------------------------------------------------------------===// 66 67 /// This class contains all of the data related to a pattern, but does not 68 /// contain any methods or logic for the actual matching. This class is solely 69 /// used to interface with the metadata of a pattern, such as the benefit or 70 /// root operation. 71 class Pattern { 72 /// This enum represents the kind of value used to select the root operations 73 /// that match this pattern. 74 enum class RootKind { 75 /// The pattern root matches "any" operation. 76 Any, 77 /// The pattern root is matched using a concrete operation name. 78 OperationName, 79 /// The pattern root is matched using an interface ID. 80 InterfaceID, 81 /// The patter root is matched using a trait ID. 82 TraitID 83 }; 84 85 public: 86 /// Return a list of operations that may be generated when rewriting an 87 /// operation instance with this pattern. getGeneratedOps()88 ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; } 89 90 /// Return the root node that this pattern matches. Patterns that can match 91 /// multiple root types return None. getRootKind()92 Optional<OperationName> getRootKind() const { 93 if (rootKind == RootKind::OperationName) 94 return OperationName::getFromOpaquePointer(rootValue); 95 return llvm::None; 96 } 97 98 /// Return the interface ID used to match the root operation of this pattern. 99 /// If the pattern does not use an interface ID for deciding the root match, 100 /// this returns None. getRootInterfaceID()101 Optional<TypeID> getRootInterfaceID() const { 102 if (rootKind == RootKind::InterfaceID) 103 return TypeID::getFromOpaquePointer(rootValue); 104 return llvm::None; 105 } 106 107 /// Return the trait ID used to match the root operation of this pattern. 108 /// If the pattern does not use a trait ID for deciding the root match, this 109 /// returns None. getRootTraitID()110 Optional<TypeID> getRootTraitID() const { 111 if (rootKind == RootKind::TraitID) 112 return TypeID::getFromOpaquePointer(rootValue); 113 return llvm::None; 114 } 115 116 /// Return the benefit (the inverse of "cost") of matching this pattern. The 117 /// benefit of a Pattern is always static - rewrites that may have dynamic 118 /// benefit can be instantiated multiple times (different Pattern instances) 119 /// for each benefit that they may return, and be guarded by different match 120 /// condition predicates. getBenefit()121 PatternBenefit getBenefit() const { return benefit; } 122 123 /// Returns true if this pattern is known to result in recursive application, 124 /// i.e. this pattern may generate IR that also matches this pattern, but is 125 /// known to bound the recursion. This signals to a rewrite driver that it is 126 /// safe to apply this pattern recursively to generated IR. hasBoundedRewriteRecursion()127 bool hasBoundedRewriteRecursion() const { 128 return contextAndHasBoundedRecursion.getInt(); 129 } 130 131 /// Return the MLIRContext used to create this pattern. getContext()132 MLIRContext *getContext() const { 133 return contextAndHasBoundedRecursion.getPointer(); 134 } 135 136 /// Return a readable name for this pattern. This name should only be used for 137 /// debugging purposes, and may be empty. getDebugName()138 StringRef getDebugName() const { return debugName; } 139 140 /// Set the human readable debug name used for this pattern. This name will 141 /// only be used for debugging purposes. setDebugName(StringRef name)142 void setDebugName(StringRef name) { debugName = name; } 143 144 /// Return the set of debug labels attached to this pattern. getDebugLabels()145 ArrayRef<StringRef> getDebugLabels() const { return debugLabels; } 146 147 /// Add the provided debug labels to this pattern. addDebugLabels(ArrayRef<StringRef> labels)148 void addDebugLabels(ArrayRef<StringRef> labels) { 149 debugLabels.append(labels.begin(), labels.end()); 150 } addDebugLabels(StringRef label)151 void addDebugLabels(StringRef label) { debugLabels.push_back(label); } 152 153 protected: 154 /// This class acts as a special tag that makes the desire to match "any" 155 /// operation type explicit. This helps to avoid unnecessary usages of this 156 /// feature, and ensures that the user is making a conscious decision. 157 struct MatchAnyOpTypeTag {}; 158 /// This class acts as a special tag that makes the desire to match any 159 /// operation that implements a given interface explicit. This helps to avoid 160 /// unnecessary usages of this feature, and ensures that the user is making a 161 /// conscious decision. 162 struct MatchInterfaceOpTypeTag {}; 163 /// This class acts as a special tag that makes the desire to match any 164 /// operation that implements a given trait explicit. This helps to avoid 165 /// unnecessary usages of this feature, and ensures that the user is making a 166 /// conscious decision. 167 struct MatchTraitOpTypeTag {}; 168 169 /// Construct a pattern with a certain benefit that matches the operation 170 /// with the given root name. 171 Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context, 172 ArrayRef<StringRef> generatedNames = {}); 173 /// Construct a pattern that may match any operation type. `generatedNames` 174 /// contains the names of operations that may be generated during a successful 175 /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any" 176 /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should 177 /// always be supplied here. 178 Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context, 179 ArrayRef<StringRef> generatedNames = {}); 180 /// Construct a pattern that may match any operation that implements the 181 /// interface defined by the provided `interfaceID`. `generatedNames` contains 182 /// the names of operations that may be generated during a successful rewrite. 183 /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match 184 /// interface" behavior is what the user actually desired, 185 /// `MatchInterfaceOpTypeTag()` should always be supplied here. 186 Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID, 187 PatternBenefit benefit, MLIRContext *context, 188 ArrayRef<StringRef> generatedNames = {}); 189 /// Construct a pattern that may match any operation that implements the 190 /// trait defined by the provided `traitID`. `generatedNames` contains the 191 /// names of operations that may be generated during a successful rewrite. 192 /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait" 193 /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should 194 /// always be supplied here. 195 Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit, 196 MLIRContext *context, ArrayRef<StringRef> generatedNames = {}); 197 198 /// Set the flag detailing if this pattern has bounded rewrite recursion or 199 /// not. 200 void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) { 201 contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg); 202 } 203 204 private: 205 Pattern(const void *rootValue, RootKind rootKind, 206 ArrayRef<StringRef> generatedNames, PatternBenefit benefit, 207 MLIRContext *context); 208 209 /// The value used to match the root operation of the pattern. 210 const void *rootValue; 211 RootKind rootKind; 212 213 /// The expected benefit of matching this pattern. 214 const PatternBenefit benefit; 215 216 /// The context this pattern was created from, and a boolean flag indicating 217 /// whether this pattern has bounded recursion or not. 218 llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion; 219 220 /// A list of the potential operations that may be generated when rewriting 221 /// an op with this pattern. 222 SmallVector<OperationName, 2> generatedOps; 223 224 /// A readable name for this pattern. May be empty. 225 StringRef debugName; 226 227 /// The set of debug labels attached to this pattern. 228 SmallVector<StringRef, 0> debugLabels; 229 }; 230 231 //===----------------------------------------------------------------------===// 232 // RewritePattern 233 //===----------------------------------------------------------------------===// 234 235 /// RewritePattern is the common base class for all DAG to DAG replacements. 236 /// There are two possible usages of this class: 237 /// * Multi-step RewritePattern with "match" and "rewrite" 238 /// - By overloading the "match" and "rewrite" functions, the user can 239 /// separate the concerns of matching and rewriting. 240 /// * Single-step RewritePattern with "matchAndRewrite" 241 /// - By overloading the "matchAndRewrite" function, the user can perform 242 /// the rewrite in the same call as the match. 243 /// 244 class RewritePattern : public Pattern { 245 public: 246 virtual ~RewritePattern() = default; 247 248 /// Rewrite the IR rooted at the specified operation with the result of 249 /// this pattern, generating any new operations with the specified 250 /// builder. If an unexpected error is encountered (an internal 251 /// compiler error), it is emitted through the normal MLIR diagnostic 252 /// hooks and the IR is left in a valid state. 253 virtual void rewrite(Operation *op, PatternRewriter &rewriter) const; 254 255 /// Attempt to match against code rooted at the specified operation, 256 /// which is the same operation code as getRootKind(). 257 virtual LogicalResult match(Operation *op) const; 258 259 /// Attempt to match against code rooted at the specified operation, 260 /// which is the same operation code as getRootKind(). If successful, this 261 /// function will automatically perform the rewrite. matchAndRewrite(Operation * op,PatternRewriter & rewriter)262 virtual LogicalResult matchAndRewrite(Operation *op, 263 PatternRewriter &rewriter) const { 264 if (succeeded(match(op))) { 265 rewrite(op, rewriter); 266 return success(); 267 } 268 return failure(); 269 } 270 271 /// This method provides a convenient interface for creating and initializing 272 /// derived rewrite patterns of the given type `T`. 273 template <typename T, typename... Args> create(Args &&...args)274 static std::unique_ptr<T> create(Args &&...args) { 275 std::unique_ptr<T> pattern = 276 std::make_unique<T>(std::forward<Args>(args)...); 277 initializePattern<T>(*pattern); 278 279 // Set a default debug name if one wasn't provided. 280 if (pattern->getDebugName().empty()) 281 pattern->setDebugName(llvm::getTypeName<T>()); 282 return pattern; 283 } 284 285 protected: 286 /// Inherit the base constructors from `Pattern`. 287 using Pattern::Pattern; 288 289 private: 290 /// Trait to check if T provides a `getOperationName` method. 291 template <typename T, typename... Args> 292 using has_initialize = decltype(std::declval<T>().initialize()); 293 template <typename T> 294 using detect_has_initialize = llvm::is_detected<has_initialize, T>; 295 296 /// Initialize the derived pattern by calling its `initialize` method. 297 template <typename T> 298 static std::enable_if_t<detect_has_initialize<T>::value> initializePattern(T & pattern)299 initializePattern(T &pattern) { 300 pattern.initialize(); 301 } 302 /// Empty derived pattern initializer for patterns that do not have an 303 /// initialize method. 304 template <typename T> 305 static std::enable_if_t<!detect_has_initialize<T>::value> initializePattern(T &)306 initializePattern(T &) {} 307 308 /// An anchor for the virtual table. 309 virtual void anchor(); 310 }; 311 312 namespace detail { 313 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that 314 /// allows for matching and rewriting against an instance of a derived operation 315 /// class or Interface. 316 template <typename SourceOp> 317 struct OpOrInterfaceRewritePatternBase : public RewritePattern { 318 using RewritePattern::RewritePattern; 319 320 /// Wrappers around the RewritePattern methods that pass the derived op type. rewriteOpOrInterfaceRewritePatternBase321 void rewrite(Operation *op, PatternRewriter &rewriter) const final { 322 rewrite(cast<SourceOp>(op), rewriter); 323 } matchOpOrInterfaceRewritePatternBase324 LogicalResult match(Operation *op) const final { 325 return match(cast<SourceOp>(op)); 326 } matchAndRewriteOpOrInterfaceRewritePatternBase327 LogicalResult matchAndRewrite(Operation *op, 328 PatternRewriter &rewriter) const final { 329 return matchAndRewrite(cast<SourceOp>(op), rewriter); 330 } 331 332 /// Rewrite and Match methods that operate on the SourceOp type. These must be 333 /// overridden by the derived pattern class. rewriteOpOrInterfaceRewritePatternBase334 virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { 335 llvm_unreachable("must override rewrite or matchAndRewrite"); 336 } matchOpOrInterfaceRewritePatternBase337 virtual LogicalResult match(SourceOp op) const { 338 llvm_unreachable("must override match or matchAndRewrite"); 339 } matchAndRewriteOpOrInterfaceRewritePatternBase340 virtual LogicalResult matchAndRewrite(SourceOp op, 341 PatternRewriter &rewriter) const { 342 if (succeeded(match(op))) { 343 rewrite(op, rewriter); 344 return success(); 345 } 346 return failure(); 347 } 348 }; 349 } // namespace detail 350 351 /// OpRewritePattern is a wrapper around RewritePattern that allows for 352 /// matching and rewriting against an instance of a derived operation class as 353 /// opposed to a raw Operation. 354 template <typename SourceOp> 355 struct OpRewritePattern 356 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { 357 /// Patterns must specify the root operation name they match against, and can 358 /// also specify the benefit of the pattern matching and a list of generated 359 /// ops. 360 OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1, 361 ArrayRef<StringRef> generatedNames = {}) 362 : detail::OpOrInterfaceRewritePatternBase<SourceOp>( 363 SourceOp::getOperationName(), benefit, context, generatedNames) {} 364 }; 365 366 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for 367 /// matching and rewriting against an instance of an operation interface instead 368 /// of a raw Operation. 369 template <typename SourceOp> 370 struct OpInterfaceRewritePattern 371 : public detail::OpOrInterfaceRewritePatternBase<SourceOp> { 372 OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) 373 : detail::OpOrInterfaceRewritePatternBase<SourceOp>( 374 Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(), 375 benefit, context) {} 376 }; 377 378 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for 379 /// matching and rewriting against instances of an operation that possess a 380 /// given trait. 381 template <template <typename> class TraitType> 382 class OpTraitRewritePattern : public RewritePattern { 383 public: 384 OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) RewritePattern(Pattern::MatchTraitOpTypeTag (),TypeID::get<TraitType> (),benefit,context)385 : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(), 386 benefit, context) {} 387 }; 388 389 //===----------------------------------------------------------------------===// 390 // RewriterBase 391 //===----------------------------------------------------------------------===// 392 393 /// This class coordinates the application of a rewrite on a set of IR, 394 /// providing a way for clients to track mutations and create new operations. 395 /// This class serves as a common API for IR mutation between pattern rewrites 396 /// and non-pattern rewrites, and facilitates the development of shared 397 /// IR transformation utilities. 398 class RewriterBase : public OpBuilder, public OpBuilder::Listener { 399 public: 400 /// Move the blocks that belong to "region" before the given position in 401 /// another region "parent". The two regions must be different. The caller 402 /// is responsible for creating or updating the operation transferring flow 403 /// of control to the region and passing it the correct block arguments. 404 virtual void inlineRegionBefore(Region ®ion, Region &parent, 405 Region::iterator before); 406 void inlineRegionBefore(Region ®ion, Block *before); 407 408 /// Clone the blocks that belong to "region" before the given position in 409 /// another region "parent". The two regions must be different. The caller is 410 /// responsible for creating or updating the operation transferring flow of 411 /// control to the region and passing it the correct block arguments. 412 virtual void cloneRegionBefore(Region ®ion, Region &parent, 413 Region::iterator before, 414 BlockAndValueMapping &mapping); 415 void cloneRegionBefore(Region ®ion, Region &parent, 416 Region::iterator before); 417 void cloneRegionBefore(Region ®ion, Block *before); 418 419 /// This method replaces the uses of the results of `op` with the values in 420 /// `newValues` when the provided `functor` returns true for a specific use. 421 /// The number of values in `newValues` is required to match the number of 422 /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of 423 /// the uses of `op` were replaced. Note that in some rewriters, the given 424 /// 'functor' may be stored beyond the lifetime of the rewrite being applied. 425 /// As such, the function should not capture by reference and instead use 426 /// value capture as necessary. 427 virtual void 428 replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced, 429 llvm::unique_function<bool(OpOperand &) const> functor); replaceOpWithIf(Operation * op,ValueRange newValues,llvm::unique_function<bool (OpOperand &)const> functor)430 void replaceOpWithIf(Operation *op, ValueRange newValues, 431 llvm::unique_function<bool(OpOperand &) const> functor) { 432 replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr, 433 std::move(functor)); 434 } 435 436 /// This method replaces the uses of the results of `op` with the values in 437 /// `newValues` when a use is nested within the given `block`. The number of 438 /// values in `newValues` is required to match the number of results of `op`. 439 /// If all uses of this operation are replaced, the operation is erased. 440 void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block, 441 bool *allUsesReplaced = nullptr); 442 443 /// This method replaces the results of the operation with the specified list 444 /// of values. The number of provided values must match the number of results 445 /// of the operation. 446 virtual void replaceOp(Operation *op, ValueRange newValues); 447 448 /// Replaces the result op with a new op that is created without verification. 449 /// The result values of the two ops must be the same types. 450 template <typename OpTy, typename... Args> replaceOpWithNewOp(Operation * op,Args &&...args)451 OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { 452 auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...); 453 replaceOpWithResultsOfAnotherOp(op, newOp.getOperation()); 454 return newOp; 455 } 456 457 /// This method erases an operation that is known to have no uses. 458 virtual void eraseOp(Operation *op); 459 460 /// This method erases all operations in a block. 461 virtual void eraseBlock(Block *block); 462 463 /// Merge the operations of block 'source' into the end of block 'dest'. 464 /// 'source's predecessors must either be empty or only contain 'dest`. 465 /// 'argValues' is used to replace the block arguments of 'source' after 466 /// merging. 467 virtual void mergeBlocks(Block *source, Block *dest, 468 ValueRange argValues = llvm::None); 469 470 // Merge the operations of block 'source' before the operation 'op'. Source 471 // block should not have existing predecessors or successors. 472 void mergeBlockBefore(Block *source, Operation *op, 473 ValueRange argValues = llvm::None); 474 475 /// Split the operations starting at "before" (inclusive) out of the given 476 /// block into a new block, and return it. 477 virtual Block *splitBlock(Block *block, Block::iterator before); 478 479 /// This method is used to notify the rewriter that an in-place operation 480 /// modification is about to happen. A call to this function *must* be 481 /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`. 482 /// This is a minor efficiency win (it avoids creating a new operation and 483 /// removing the old one) but also often allows simpler code in the client. startRootUpdate(Operation * op)484 virtual void startRootUpdate(Operation *op) {} 485 486 /// This method is used to signal the end of a root update on the given 487 /// operation. This can only be called on operations that were provided to a 488 /// call to `startRootUpdate`. finalizeRootUpdate(Operation * op)489 virtual void finalizeRootUpdate(Operation *op) {} 490 491 /// This method cancels a pending root update. This can only be called on 492 /// operations that were provided to a call to `startRootUpdate`. cancelRootUpdate(Operation * op)493 virtual void cancelRootUpdate(Operation *op) {} 494 495 /// This method is a utility wrapper around a root update of an operation. It 496 /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given 497 /// callable. 498 template <typename CallableT> updateRootInPlace(Operation * root,CallableT && callable)499 void updateRootInPlace(Operation *root, CallableT &&callable) { 500 startRootUpdate(root); 501 callable(); 502 finalizeRootUpdate(root); 503 } 504 505 /// Used to notify the rewriter that the IR failed to be rewritten because of 506 /// a match failure, and provide a callback to populate a diagnostic with the 507 /// reason why the failure occurred. This method allows for derived rewriters 508 /// to optionally hook into the reason why a rewrite failed, and display it to 509 /// users. 510 template <typename CallbackT> 511 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> notifyMatchFailure(Location loc,CallbackT && reasonCallback)512 notifyMatchFailure(Location loc, CallbackT &&reasonCallback) { 513 #ifndef NDEBUG 514 return notifyMatchFailure(loc, 515 function_ref<void(Diagnostic &)>(reasonCallback)); 516 #else 517 return failure(); 518 #endif 519 } 520 template <typename CallbackT> 521 std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult> notifyMatchFailure(Operation * op,CallbackT && reasonCallback)522 notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) { 523 return notifyMatchFailure(op->getLoc(), 524 function_ref<void(Diagnostic &)>(reasonCallback)); 525 } 526 template <typename ArgT> notifyMatchFailure(ArgT && arg,const Twine & msg)527 LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) { 528 return notifyMatchFailure(std::forward<ArgT>(arg), 529 [&](Diagnostic &diag) { diag << msg; }); 530 } 531 template <typename ArgT> notifyMatchFailure(ArgT && arg,const char * msg)532 LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) { 533 return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg)); 534 } 535 536 protected: 537 /// Initialize the builder with this rewriter as the listener. RewriterBase(MLIRContext * ctx)538 explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {} RewriterBase(const OpBuilder & otherBuilder)539 explicit RewriterBase(const OpBuilder &otherBuilder) 540 : OpBuilder(otherBuilder) { 541 setListener(this); 542 } 543 ~RewriterBase() override; 544 545 /// These are the callback methods that subclasses can choose to implement if 546 /// they would like to be notified about certain types of mutations. 547 548 /// Notify the rewriter that the specified operation is about to be replaced 549 /// with another set of operations. This is called before the uses of the 550 /// operation have been changed. notifyRootReplaced(Operation * op)551 virtual void notifyRootReplaced(Operation *op) {} 552 553 /// This is called on an operation that a rewrite is removing, right before 554 /// the operation is deleted. At this point, the operation has zero uses. notifyOperationRemoved(Operation * op)555 virtual void notifyOperationRemoved(Operation *op) {} 556 557 /// Notify the rewriter that the pattern failed to match the given operation, 558 /// and provide a callback to populate a diagnostic with the reason why the 559 /// failure occurred. This method allows for derived rewriters to optionally 560 /// hook into the reason why a rewrite failed, and display it to users. 561 virtual LogicalResult notifyMatchFailure(Location loc,function_ref<void (Diagnostic &)> reasonCallback)562 notifyMatchFailure(Location loc, 563 function_ref<void(Diagnostic &)> reasonCallback) { 564 return failure(); 565 } 566 567 private: 568 void operator=(const RewriterBase &) = delete; 569 RewriterBase(const RewriterBase &) = delete; 570 571 /// 'op' and 'newOp' are known to have the same number of results, replace the 572 /// uses of op with uses of newOp. 573 void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp); 574 }; 575 576 //===----------------------------------------------------------------------===// 577 // IRRewriter 578 //===----------------------------------------------------------------------===// 579 580 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite, 581 /// providing a way to keep track of the mutations made to the IR. This class 582 /// should only be used in situations where another `RewriterBase` instance, 583 /// such as a `PatternRewriter`, is not available. 584 class IRRewriter : public RewriterBase { 585 public: IRRewriter(MLIRContext * ctx)586 explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} IRRewriter(const OpBuilder & builder)587 explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {} 588 }; 589 590 //===----------------------------------------------------------------------===// 591 // PatternRewriter 592 //===----------------------------------------------------------------------===// 593 594 /// A special type of `RewriterBase` that coordinates the application of a 595 /// rewrite pattern on the current IR being matched, providing a way to keep 596 /// track of any mutations made. This class should be used to perform all 597 /// necessary IR mutations within a rewrite pattern, as the pattern driver may 598 /// be tracking various state that would be invalidated when a mutation takes 599 /// place. 600 class PatternRewriter : public RewriterBase { 601 public: 602 using RewriterBase::RewriterBase; 603 }; 604 605 //===----------------------------------------------------------------------===// 606 // PDLPatternModule 607 //===----------------------------------------------------------------------===// 608 609 //===----------------------------------------------------------------------===// 610 // PDLValue 611 612 /// Storage type of byte-code interpreter values. These are passed to constraint 613 /// functions as arguments. 614 class PDLValue { 615 public: 616 /// The underlying kind of a PDL value. 617 enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange }; 618 619 /// Construct a new PDL value. 620 PDLValue(const PDLValue &other) = default; 621 PDLValue(std::nullptr_t = nullptr) {} PDLValue(Attribute value)622 PDLValue(Attribute value) 623 : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {} PDLValue(Operation * value)624 PDLValue(Operation *value) : value(value), kind(Kind::Operation) {} PDLValue(Type value)625 PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {} PDLValue(TypeRange * value)626 PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {} PDLValue(Value value)627 PDLValue(Value value) 628 : value(value.getAsOpaquePointer()), kind(Kind::Value) {} PDLValue(ValueRange * value)629 PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {} 630 631 /// Returns true if the type of the held value is `T`. 632 template <typename T> isa()633 bool isa() const { 634 assert(value && "isa<> used on a null value"); 635 return kind == getKindOf<T>(); 636 } 637 638 /// Attempt to dynamically cast this value to type `T`, returns null if this 639 /// value is not an instance of `T`. 640 template <typename T, 641 typename ResultT = std::conditional_t< 642 std::is_convertible<T, bool>::value, T, Optional<T>>> dyn_cast()643 ResultT dyn_cast() const { 644 return isa<T>() ? castImpl<T>() : ResultT(); 645 } 646 647 /// Cast this value to type `T`, asserts if this value is not an instance of 648 /// `T`. 649 template <typename T> cast()650 T cast() const { 651 assert(isa<T>() && "expected value to be of type `T`"); 652 return castImpl<T>(); 653 } 654 655 /// Get an opaque pointer to the value. getAsOpaquePointer()656 const void *getAsOpaquePointer() const { return value; } 657 658 /// Return if this value is null or not. 659 explicit operator bool() const { return value; } 660 661 /// Return the kind of this value. getKind()662 Kind getKind() const { return kind; } 663 664 /// Print this value to the provided output stream. 665 void print(raw_ostream &os) const; 666 667 /// Print the specified value kind to an output stream. 668 static void print(raw_ostream &os, Kind kind); 669 670 private: 671 /// Find the index of a given type in a range of other types. 672 template <typename...> 673 struct index_of_t; 674 template <typename T, typename... R> 675 struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {}; 676 template <typename T, typename F, typename... R> 677 struct index_of_t<T, F, R...> 678 : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {}; 679 680 /// Return the kind used for the given T. 681 template <typename T> 682 static Kind getKindOf() { 683 return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type, 684 TypeRange, Value, ValueRange>::value); 685 } 686 687 /// The internal implementation of `cast`, that returns the underlying value 688 /// as the given type `T`. 689 template <typename T> 690 std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T> 691 castImpl() const { 692 return T::getFromOpaquePointer(value); 693 } 694 template <typename T> 695 std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T> 696 castImpl() const { 697 return *reinterpret_cast<T *>(const_cast<void *>(value)); 698 } 699 template <typename T> 700 std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const { 701 return reinterpret_cast<T>(const_cast<void *>(value)); 702 } 703 704 /// The internal opaque representation of a PDLValue. 705 const void *value{nullptr}; 706 /// The kind of the opaque value. 707 Kind kind{Kind::Attribute}; 708 }; 709 710 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) { 711 value.print(os); 712 return os; 713 } 714 715 inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) { 716 PDLValue::print(os, kind); 717 return os; 718 } 719 720 //===----------------------------------------------------------------------===// 721 // PDLResultList 722 723 /// The class represents a list of PDL results, returned by a native rewrite 724 /// method. It provides the mechanism with which to pass PDLValues back to the 725 /// PDL bytecode. 726 class PDLResultList { 727 public: 728 /// Push a new Attribute value onto the result list. 729 void push_back(Attribute value) { results.push_back(value); } 730 731 /// Push a new Operation onto the result list. 732 void push_back(Operation *value) { results.push_back(value); } 733 734 /// Push a new Type onto the result list. 735 void push_back(Type value) { results.push_back(value); } 736 737 /// Push a new TypeRange onto the result list. 738 void push_back(TypeRange value) { 739 // The lifetime of a TypeRange can't be guaranteed, so we'll need to 740 // allocate a storage for it. 741 llvm::OwningArrayRef<Type> storage(value.size()); 742 llvm::copy(value, storage.begin()); 743 allocatedTypeRanges.emplace_back(std::move(storage)); 744 typeRanges.push_back(allocatedTypeRanges.back()); 745 results.push_back(&typeRanges.back()); 746 } 747 void push_back(ValueTypeRange<OperandRange> value) { 748 typeRanges.push_back(value); 749 results.push_back(&typeRanges.back()); 750 } 751 void push_back(ValueTypeRange<ResultRange> value) { 752 typeRanges.push_back(value); 753 results.push_back(&typeRanges.back()); 754 } 755 756 /// Push a new Value onto the result list. 757 void push_back(Value value) { results.push_back(value); } 758 759 /// Push a new ValueRange onto the result list. 760 void push_back(ValueRange value) { 761 // The lifetime of a ValueRange can't be guaranteed, so we'll need to 762 // allocate a storage for it. 763 llvm::OwningArrayRef<Value> storage(value.size()); 764 llvm::copy(value, storage.begin()); 765 allocatedValueRanges.emplace_back(std::move(storage)); 766 valueRanges.push_back(allocatedValueRanges.back()); 767 results.push_back(&valueRanges.back()); 768 } 769 void push_back(OperandRange value) { 770 valueRanges.push_back(value); 771 results.push_back(&valueRanges.back()); 772 } 773 void push_back(ResultRange value) { 774 valueRanges.push_back(value); 775 results.push_back(&valueRanges.back()); 776 } 777 778 protected: 779 /// Create a new result list with the expected number of results. 780 PDLResultList(unsigned maxNumResults) { 781 // For now just reserve enough space for all of the results. We could do 782 // separate counts per range type, but it isn't really worth it unless there 783 // are a "large" number of results. 784 typeRanges.reserve(maxNumResults); 785 valueRanges.reserve(maxNumResults); 786 } 787 788 /// The PDL results held by this list. 789 SmallVector<PDLValue> results; 790 /// Memory used to store ranges held by the list. 791 SmallVector<TypeRange> typeRanges; 792 SmallVector<ValueRange> valueRanges; 793 /// Memory allocated to store ranges in the result list whose lifetime was 794 /// generated in the native function. 795 SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges; 796 SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges; 797 }; 798 799 //===----------------------------------------------------------------------===// 800 // PDLPatternModule 801 802 /// A generic PDL pattern constraint function. This function applies a 803 /// constraint to a given set of opaque PDLValue entities. Returns success if 804 /// the constraint successfully held, failure otherwise. 805 using PDLConstraintFunction = 806 std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>; 807 /// A native PDL rewrite function. This function performs a rewrite on the 808 /// given set of values. Any results from this rewrite that should be passed 809 /// back to PDL should be added to the provided result list. This method is only 810 /// invoked when the corresponding match was successful. 811 using PDLRewriteFunction = 812 std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>; 813 814 namespace detail { 815 namespace pdl_function_builder { 816 /// A utility variable that always resolves to false. This is useful for static 817 /// asserts that are always false, but only should fire in certain templated 818 /// constructs. For example, if a templated function should never be called, the 819 /// function could be defined as: 820 /// 821 /// template <typename T> 822 /// void foo() { 823 /// static_assert(always_false<T>, "This function should never be called"); 824 /// } 825 /// 826 template <class... T> 827 constexpr bool always_false = false; 828 829 //===----------------------------------------------------------------------===// 830 // PDL Function Builder: Type Processing 831 //===----------------------------------------------------------------------===// 832 833 /// This struct provides a convenient way to determine how to process a given 834 /// type as either a PDL parameter, or a result value. This allows for 835 /// supporting complex types in constraint and rewrite functions, without 836 /// requiring the user to hand-write the necessary glue code themselves. 837 /// Specializations of this class should implement the following methods to 838 /// enable support as a PDL argument or result type: 839 /// 840 /// static LogicalResult verifyAsArg( 841 /// function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue, 842 /// size_t argIdx); 843 /// 844 /// * This method verifies that the given PDLValue is valid for use as a 845 /// value of `T`. 846 /// 847 /// static T processAsArg(PDLValue pdlValue); 848 /// 849 /// * This method processes the given PDLValue as a value of `T`. 850 /// 851 /// static void processAsResult(PatternRewriter &, PDLResultList &results, 852 /// const T &value); 853 /// 854 /// * This method processes the given value of `T` as the result of a 855 /// function invocation. The method should package the value into an 856 /// appropriate form and append it to the given result list. 857 /// 858 /// If the type `T` is based on a higher order value, consider using 859 /// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify 860 /// the implementation. 861 /// 862 template <typename T, typename Enable = void> 863 struct ProcessPDLValue; 864 865 /// This struct provides a simplified model for processing types that are based 866 /// on another type, e.g. APInt is based on the handling for IntegerAttr. This 867 /// allows for building the necessary processing functions on top of the base 868 /// value instead of a PDLValue. Derived users should implement the following 869 /// (which subsume the ProcessPDLValue variants): 870 /// 871 /// static LogicalResult verifyAsArg( 872 /// function_ref<LogicalResult(const Twine &)> errorFn, 873 /// const BaseT &baseValue, size_t argIdx); 874 /// 875 /// * This method verifies that the given PDLValue is valid for use as a 876 /// value of `T`. 877 /// 878 /// static T processAsArg(BaseT baseValue); 879 /// 880 /// * This method processes the given base value as a value of `T`. 881 /// 882 template <typename T, typename BaseT> 883 struct ProcessPDLValueBasedOn { 884 static LogicalResult 885 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, 886 PDLValue pdlValue, size_t argIdx) { 887 // Verify the base class before continuing. 888 if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx))) 889 return failure(); 890 return ProcessPDLValue<T>::verifyAsArg( 891 errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx); 892 } 893 static T processAsArg(PDLValue pdlValue) { 894 return ProcessPDLValue<T>::processAsArg( 895 ProcessPDLValue<BaseT>::processAsArg(pdlValue)); 896 } 897 898 /// Explicitly add the expected parent API to ensure the parent class 899 /// implements the necessary API (and doesn't implicitly inherit it from 900 /// somewhere else). 901 static LogicalResult 902 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value, 903 size_t argIdx) { 904 return success(); 905 } 906 static T processAsArg(BaseT baseValue); 907 }; 908 909 /// This struct provides a simplified model for processing types that have 910 /// "builtin" PDLValue support: 911 /// * Attribute, Operation *, Type, TypeRange, ValueRange 912 template <typename T> 913 struct ProcessBuiltinPDLValue { 914 static LogicalResult 915 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, 916 PDLValue pdlValue, size_t argIdx) { 917 if (pdlValue) 918 return success(); 919 return errorFn("expected a non-null value for argument " + Twine(argIdx) + 920 " of type: " + llvm::getTypeName<T>()); 921 } 922 923 static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); } 924 static void processAsResult(PatternRewriter &, PDLResultList &results, 925 T value) { 926 results.push_back(value); 927 } 928 }; 929 930 /// This struct provides a simplified model for processing types that inherit 931 /// from builtin PDLValue types. For example, derived attributes like 932 /// IntegerAttr, derived types like IntegerType, derived operations like 933 /// ModuleOp, Interfaces, etc. 934 template <typename T, typename BaseT> 935 struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> { 936 static LogicalResult 937 verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, 938 BaseT baseValue, size_t argIdx) { 939 return TypeSwitch<BaseT, LogicalResult>(baseValue) 940 .Case([&](T) { return success(); }) 941 .Default([&](BaseT) { 942 return errorFn("expected argument " + Twine(argIdx) + 943 " to be of type: " + llvm::getTypeName<T>()); 944 }); 945 } 946 using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg; 947 948 static T processAsArg(BaseT baseValue) { 949 return baseValue.template cast<T>(); 950 } 951 using ProcessPDLValueBasedOn<T, BaseT>::processAsArg; 952 953 static void processAsResult(PatternRewriter &, PDLResultList &results, 954 T value) { 955 results.push_back(value); 956 } 957 }; 958 959 //===----------------------------------------------------------------------===// 960 // Attribute 961 962 template <> 963 struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {}; 964 template <typename T> 965 struct ProcessPDLValue<T, 966 std::enable_if_t<std::is_base_of<Attribute, T>::value>> 967 : public ProcessDerivedPDLValue<T, Attribute> {}; 968 969 /// Handling for various Attribute value types. 970 template <> 971 struct ProcessPDLValue<StringRef> 972 : public ProcessPDLValueBasedOn<StringRef, StringAttr> { 973 static StringRef processAsArg(StringAttr value) { return value.getValue(); } 974 using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg; 975 976 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, 977 StringRef value) { 978 results.push_back(rewriter.getStringAttr(value)); 979 } 980 }; 981 template <> 982 struct ProcessPDLValue<std::string> 983 : public ProcessPDLValueBasedOn<std::string, StringAttr> { 984 template <typename T> 985 static std::string processAsArg(T value) { 986 static_assert(always_false<T>, 987 "`std::string` arguments require a string copy, use " 988 "`StringRef` for string-like arguments instead"); 989 return {}; 990 } 991 static void processAsResult(PatternRewriter &rewriter, PDLResultList &results, 992 StringRef value) { 993 results.push_back(rewriter.getStringAttr(value)); 994 } 995 }; 996 997 //===----------------------------------------------------------------------===// 998 // Operation 999 1000 template <> 1001 struct ProcessPDLValue<Operation *> 1002 : public ProcessBuiltinPDLValue<Operation *> {}; 1003 template <typename T> 1004 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>> 1005 : public ProcessDerivedPDLValue<T, Operation *> { 1006 static T processAsArg(Operation *value) { return cast<T>(value); } 1007 }; 1008 1009 //===----------------------------------------------------------------------===// 1010 // Type 1011 1012 template <> 1013 struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {}; 1014 template <typename T> 1015 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>> 1016 : public ProcessDerivedPDLValue<T, Type> {}; 1017 1018 //===----------------------------------------------------------------------===// 1019 // TypeRange 1020 1021 template <> 1022 struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {}; 1023 template <> 1024 struct ProcessPDLValue<ValueTypeRange<OperandRange>> { 1025 static void processAsResult(PatternRewriter &, PDLResultList &results, 1026 ValueTypeRange<OperandRange> types) { 1027 results.push_back(types); 1028 } 1029 }; 1030 template <> 1031 struct ProcessPDLValue<ValueTypeRange<ResultRange>> { 1032 static void processAsResult(PatternRewriter &, PDLResultList &results, 1033 ValueTypeRange<ResultRange> types) { 1034 results.push_back(types); 1035 } 1036 }; 1037 1038 //===----------------------------------------------------------------------===// 1039 // Value 1040 1041 template <> 1042 struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {}; 1043 1044 //===----------------------------------------------------------------------===// 1045 // ValueRange 1046 1047 template <> 1048 struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> { 1049 }; 1050 template <> 1051 struct ProcessPDLValue<OperandRange> { 1052 static void processAsResult(PatternRewriter &, PDLResultList &results, 1053 OperandRange values) { 1054 results.push_back(values); 1055 } 1056 }; 1057 template <> 1058 struct ProcessPDLValue<ResultRange> { 1059 static void processAsResult(PatternRewriter &, PDLResultList &results, 1060 ResultRange values) { 1061 results.push_back(values); 1062 } 1063 }; 1064 1065 //===----------------------------------------------------------------------===// 1066 // PDL Function Builder: Argument Handling 1067 //===----------------------------------------------------------------------===// 1068 1069 /// Validate the given PDLValues match the constraints defined by the argument 1070 /// types of the given function. In the case of failure, a match failure 1071 /// diagnostic is emitted. 1072 /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL 1073 /// does not currently preserve Constraint application ordering. 1074 template <typename PDLFnT, std::size_t... I> 1075 LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, 1076 std::index_sequence<I...>) { 1077 using FnTraitsT = llvm::function_traits<PDLFnT>; 1078 1079 auto errorFn = [&](const Twine &msg) { 1080 return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg); 1081 }; 1082 LogicalResult result = success(); 1083 (void)std::initializer_list<int>{ 1084 (result = 1085 succeeded(result) 1086 ? ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: 1087 verifyAsArg(errorFn, values[I], I) 1088 : failure(), 1089 0)...}; 1090 return result; 1091 } 1092 1093 /// Assert that the given PDLValues match the constraints defined by the 1094 /// arguments of the given function. In the case of failure, a fatal error 1095 /// is emitted. 1096 template <typename PDLFnT, std::size_t... I> 1097 void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values, 1098 std::index_sequence<I...>) { 1099 // We only want to do verification in debug builds, same as with `assert`. 1100 #if LLVM_ENABLE_ABI_BREAKING_CHECKS 1101 using FnTraitsT = llvm::function_traits<PDLFnT>; 1102 auto errorFn = [&](const Twine &msg) -> LogicalResult { 1103 llvm::report_fatal_error(msg); 1104 }; 1105 (void)std::initializer_list<int>{ 1106 (assert(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t< 1107 I + 1>>::verifyAsArg(errorFn, values[I], I))), 1108 0)...}; 1109 #endif 1110 } 1111 1112 //===----------------------------------------------------------------------===// 1113 // PDL Function Builder: Results Handling 1114 //===----------------------------------------------------------------------===// 1115 1116 /// Store a single result within the result list. 1117 template <typename T> 1118 static void processResults(PatternRewriter &rewriter, PDLResultList &results, 1119 T &&value) { 1120 ProcessPDLValue<T>::processAsResult(rewriter, results, 1121 std::forward<T>(value)); 1122 } 1123 1124 /// Store a std::pair<> as individual results within the result list. 1125 template <typename T1, typename T2> 1126 static void processResults(PatternRewriter &rewriter, PDLResultList &results, 1127 std::pair<T1, T2> &&pair) { 1128 processResults(rewriter, results, std::move(pair.first)); 1129 processResults(rewriter, results, std::move(pair.second)); 1130 } 1131 1132 /// Store a std::tuple<> as individual results within the result list. 1133 template <typename... Ts> 1134 static void processResults(PatternRewriter &rewriter, PDLResultList &results, 1135 std::tuple<Ts...> &&tuple) { 1136 auto applyFn = [&](auto &&...args) { 1137 // TODO: Use proper fold expressions when we have C++17. For now we use a 1138 // bogus std::initializer_list to work around C++14 limitations. 1139 (void)std::initializer_list<int>{ 1140 (processResults(rewriter, results, std::move(args)), 0)...}; 1141 }; 1142 llvm::apply_tuple(applyFn, std::move(tuple)); 1143 } 1144 1145 //===----------------------------------------------------------------------===// 1146 // PDL Constraint Builder 1147 //===----------------------------------------------------------------------===// 1148 1149 /// Process the arguments of a native constraint and invoke it. 1150 template <typename PDLFnT, std::size_t... I, 1151 typename FnTraitsT = llvm::function_traits<PDLFnT>> 1152 typename FnTraitsT::result_t 1153 processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter, 1154 ArrayRef<PDLValue> values, 1155 std::index_sequence<I...>) { 1156 return fn( 1157 rewriter, 1158 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( 1159 values[I]))...); 1160 } 1161 1162 /// Build a constraint function from the given function `ConstraintFnT`. This 1163 /// allows for enabling the user to define simpler, more direct constraint 1164 /// functions without needing to handle the low-level PDL goop. 1165 /// 1166 /// If the constraint function is already in the correct form, we just forward 1167 /// it directly. 1168 template <typename ConstraintFnT> 1169 std::enable_if_t< 1170 std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, 1171 PDLConstraintFunction> 1172 buildConstraintFn(ConstraintFnT &&constraintFn) { 1173 return std::forward<ConstraintFnT>(constraintFn); 1174 } 1175 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form 1176 /// we desire. 1177 template <typename ConstraintFnT> 1178 std::enable_if_t< 1179 !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value, 1180 PDLConstraintFunction> 1181 buildConstraintFn(ConstraintFnT &&constraintFn) { 1182 return [constraintFn = std::forward<ConstraintFnT>(constraintFn)]( 1183 PatternRewriter &rewriter, 1184 ArrayRef<PDLValue> values) -> LogicalResult { 1185 auto argIndices = std::make_index_sequence< 1186 llvm::function_traits<ConstraintFnT>::num_args - 1>(); 1187 if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices))) 1188 return failure(); 1189 return processArgsAndInvokeConstraint(constraintFn, rewriter, values, 1190 argIndices); 1191 }; 1192 } 1193 1194 //===----------------------------------------------------------------------===// 1195 // PDL Rewrite Builder 1196 //===----------------------------------------------------------------------===// 1197 1198 /// Process the arguments of a native rewrite and invoke it. 1199 /// This overload handles the case of no return values. 1200 template <typename PDLFnT, std::size_t... I, 1201 typename FnTraitsT = llvm::function_traits<PDLFnT>> 1202 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value> 1203 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, 1204 PDLResultList &, ArrayRef<PDLValue> values, 1205 std::index_sequence<I...>) { 1206 fn(rewriter, 1207 (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg( 1208 values[I]))...); 1209 } 1210 /// This overload handles the case of return values, which need to be packaged 1211 /// into the result list. 1212 template <typename PDLFnT, std::size_t... I, 1213 typename FnTraitsT = llvm::function_traits<PDLFnT>> 1214 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value> 1215 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter, 1216 PDLResultList &results, ArrayRef<PDLValue> values, 1217 std::index_sequence<I...>) { 1218 processResults( 1219 rewriter, results, 1220 fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>:: 1221 processAsArg(values[I]))...)); 1222 } 1223 1224 /// Build a rewrite function from the given function `RewriteFnT`. This 1225 /// allows for enabling the user to define simpler, more direct rewrite 1226 /// functions without needing to handle the low-level PDL goop. 1227 /// 1228 /// If the rewrite function is already in the correct form, we just forward 1229 /// it directly. 1230 template <typename RewriteFnT> 1231 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, 1232 PDLRewriteFunction> 1233 buildRewriteFn(RewriteFnT &&rewriteFn) { 1234 return std::forward<RewriteFnT>(rewriteFn); 1235 } 1236 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form 1237 /// we desire. 1238 template <typename RewriteFnT> 1239 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value, 1240 PDLRewriteFunction> 1241 buildRewriteFn(RewriteFnT &&rewriteFn) { 1242 return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)]( 1243 PatternRewriter &rewriter, PDLResultList &results, 1244 ArrayRef<PDLValue> values) { 1245 auto argIndices = 1246 std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args - 1247 1>(); 1248 assertArgs<RewriteFnT>(rewriter, values, argIndices); 1249 processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values, 1250 argIndices); 1251 }; 1252 } 1253 1254 } // namespace pdl_function_builder 1255 } // namespace detail 1256 1257 /// This class contains all of the necessary data for a set of PDL patterns, or 1258 /// pattern rewrites specified in the form of the PDL dialect. This PDL module 1259 /// contained by this pattern may contain any number of `pdl.pattern` 1260 /// operations. 1261 class PDLPatternModule { 1262 public: 1263 PDLPatternModule() = default; 1264 1265 /// Construct a PDL pattern with the given module. 1266 PDLPatternModule(OwningOpRef<ModuleOp> pdlModule) 1267 : pdlModule(std::move(pdlModule)) {} 1268 1269 /// Merge the state in `other` into this pattern module. 1270 void mergeIn(PDLPatternModule &&other); 1271 1272 /// Return the internal PDL module of this pattern. 1273 ModuleOp getModule() { return pdlModule.get(); } 1274 1275 //===--------------------------------------------------------------------===// 1276 // Function Registry 1277 1278 /// Register a constraint function with PDL. A constraint function may be 1279 /// specified in one of two ways: 1280 /// 1281 /// * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)` 1282 /// 1283 /// In this overload the arguments of the constraint function are passed via 1284 /// the low-level PDLValue form. 1285 /// 1286 /// * `LogicalResult (PatternRewriter &, ValueTs... values)` 1287 /// 1288 /// In this form the arguments of the constraint function are passed via the 1289 /// expected high level C++ type. In this form, the framework will 1290 /// automatically unwrap PDLValues and convert them to the expected ValueTs. 1291 /// For example, if the constraint function accepts a `Operation *`, the 1292 /// framework will automatically cast the input PDLValue. In the case of a 1293 /// `StringRef`, the framework will automatically unwrap the argument as a 1294 /// StringAttr and pass the underlying string value. To see the full list of 1295 /// supported types, or to see how to add handling for custom types, view 1296 /// the definition of `ProcessPDLValue` above. 1297 void registerConstraintFunction(StringRef name, 1298 PDLConstraintFunction constraintFn); 1299 template <typename ConstraintFnT> 1300 void registerConstraintFunction(StringRef name, 1301 ConstraintFnT &&constraintFn) { 1302 registerConstraintFunction(name, 1303 detail::pdl_function_builder::buildConstraintFn( 1304 std::forward<ConstraintFnT>(constraintFn))); 1305 } 1306 1307 /// Register a rewrite function with PDL. A rewrite function may be specified 1308 /// in one of two ways: 1309 /// 1310 /// * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)` 1311 /// 1312 /// In this overload the arguments of the constraint function are passed via 1313 /// the low-level PDLValue form, and the results are manually appended to 1314 /// the given result list. 1315 /// 1316 /// * `ResultT (PatternRewriter &, ValueTs... values)` 1317 /// 1318 /// In this form the arguments and result of the rewrite function are passed 1319 /// via the expected high level C++ type. In this form, the framework will 1320 /// automatically unwrap the PDLValues arguments and convert them to the 1321 /// expected ValueTs. It will also automatically handle the processing and 1322 /// packaging of the result value to the result list. For example, if the 1323 /// rewrite function takes a `Operation *`, the framework will automatically 1324 /// cast the input PDLValue. In the case of a `StringRef`, the framework 1325 /// will automatically unwrap the argument as a StringAttr and pass the 1326 /// underlying string value. In the reverse case, if the rewrite returns a 1327 /// StringRef or std::string, it will automatically package this as a 1328 /// StringAttr and append it to the result list. To see the full list of 1329 /// supported types, or to see how to add handling for custom types, view 1330 /// the definition of `ProcessPDLValue` above. 1331 void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn); 1332 template <typename RewriteFnT> 1333 void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) { 1334 registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn( 1335 std::forward<RewriteFnT>(rewriteFn))); 1336 } 1337 1338 /// Return the set of the registered constraint functions. 1339 const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const { 1340 return constraintFunctions; 1341 } 1342 llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() { 1343 return constraintFunctions; 1344 } 1345 /// Return the set of the registered rewrite functions. 1346 const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const { 1347 return rewriteFunctions; 1348 } 1349 llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() { 1350 return rewriteFunctions; 1351 } 1352 1353 /// Clear out the patterns and functions within this module. 1354 void clear() { 1355 pdlModule = nullptr; 1356 constraintFunctions.clear(); 1357 rewriteFunctions.clear(); 1358 } 1359 1360 private: 1361 /// The module containing the `pdl.pattern` operations. 1362 OwningOpRef<ModuleOp> pdlModule; 1363 1364 /// The external functions referenced from within the PDL module. 1365 llvm::StringMap<PDLConstraintFunction> constraintFunctions; 1366 llvm::StringMap<PDLRewriteFunction> rewriteFunctions; 1367 }; 1368 1369 //===----------------------------------------------------------------------===// 1370 // RewritePatternSet 1371 //===----------------------------------------------------------------------===// 1372 1373 class RewritePatternSet { 1374 using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>; 1375 1376 public: 1377 RewritePatternSet(MLIRContext *context) : context(context) {} 1378 1379 /// Construct a RewritePatternSet populated with the given pattern. 1380 RewritePatternSet(MLIRContext *context, 1381 std::unique_ptr<RewritePattern> pattern) 1382 : context(context) { 1383 nativePatterns.emplace_back(std::move(pattern)); 1384 } 1385 RewritePatternSet(PDLPatternModule &&pattern) 1386 : context(pattern.getModule()->getContext()), 1387 pdlPatterns(std::move(pattern)) {} 1388 1389 MLIRContext *getContext() const { return context; } 1390 1391 /// Return the native patterns held in this list. 1392 NativePatternListT &getNativePatterns() { return nativePatterns; } 1393 1394 /// Return the PDL patterns held in this list. 1395 PDLPatternModule &getPDLPatterns() { return pdlPatterns; } 1396 1397 /// Clear out all of the held patterns in this list. 1398 void clear() { 1399 nativePatterns.clear(); 1400 pdlPatterns.clear(); 1401 } 1402 1403 //===--------------------------------------------------------------------===// 1404 // 'add' methods for adding patterns to the set. 1405 //===--------------------------------------------------------------------===// 1406 1407 /// Add an instance of each of the pattern types 'Ts' to the pattern list with 1408 /// the given arguments. Return a reference to `this` for chaining insertions. 1409 /// Note: ConstructorArg is necessary here to separate the two variadic lists. 1410 template <typename... Ts, typename ConstructorArg, 1411 typename... ConstructorArgs, 1412 typename = std::enable_if_t<sizeof...(Ts) != 0>> 1413 RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) { 1414 // The following expands a call to emplace_back for each of the pattern 1415 // types 'Ts'. This magic is necessary due to a limitation in the places 1416 // that a parameter pack can be expanded in c++11. 1417 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 1418 (void)std::initializer_list<int>{ 1419 0, (addImpl<Ts>(/*debugLabels=*/llvm::None, 1420 std::forward<ConstructorArg>(arg), 1421 std::forward<ConstructorArgs>(args)...), 1422 0)...}; 1423 return *this; 1424 } 1425 /// An overload of the above `add` method that allows for attaching a set 1426 /// of debug labels to the attached patterns. This is useful for labeling 1427 /// groups of patterns that may be shared between multiple different 1428 /// passes/users. 1429 template <typename... Ts, typename ConstructorArg, 1430 typename... ConstructorArgs, 1431 typename = std::enable_if_t<sizeof...(Ts) != 0>> 1432 RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels, 1433 ConstructorArg &&arg, 1434 ConstructorArgs &&...args) { 1435 // The following expands a call to emplace_back for each of the pattern 1436 // types 'Ts'. This magic is necessary due to a limitation in the places 1437 // that a parameter pack can be expanded in c++11. 1438 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 1439 (void)std::initializer_list<int>{ 1440 0, (addImpl<Ts>(debugLabels, arg, args...), 0)...}; 1441 return *this; 1442 } 1443 1444 /// Add an instance of each of the pattern types 'Ts'. Return a reference to 1445 /// `this` for chaining insertions. 1446 template <typename... Ts> 1447 RewritePatternSet &add() { 1448 (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...}; 1449 return *this; 1450 } 1451 1452 /// Add the given native pattern to the pattern list. Return a reference to 1453 /// `this` for chaining insertions. 1454 RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) { 1455 nativePatterns.emplace_back(std::move(pattern)); 1456 return *this; 1457 } 1458 1459 /// Add the given PDL pattern to the pattern list. Return a reference to 1460 /// `this` for chaining insertions. 1461 RewritePatternSet &add(PDLPatternModule &&pattern) { 1462 pdlPatterns.mergeIn(std::move(pattern)); 1463 return *this; 1464 } 1465 1466 // Add a matchAndRewrite style pattern represented as a C function pointer. 1467 template <typename OpType> 1468 RewritePatternSet &add(LogicalResult (*implFn)(OpType, 1469 PatternRewriter &rewriter)) { 1470 struct FnPattern final : public OpRewritePattern<OpType> { 1471 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), 1472 MLIRContext *context) 1473 : OpRewritePattern<OpType>(context), implFn(implFn) {} 1474 1475 LogicalResult matchAndRewrite(OpType op, 1476 PatternRewriter &rewriter) const override { 1477 return implFn(op, rewriter); 1478 } 1479 1480 private: 1481 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); 1482 }; 1483 add(std::make_unique<FnPattern>(std::move(implFn), getContext())); 1484 return *this; 1485 } 1486 1487 //===--------------------------------------------------------------------===// 1488 // Pattern Insertion 1489 //===--------------------------------------------------------------------===// 1490 1491 // TODO: These are soft deprecated in favor of the 'add' methods above. 1492 1493 /// Add an instance of each of the pattern types 'Ts' to the pattern list with 1494 /// the given arguments. Return a reference to `this` for chaining insertions. 1495 /// Note: ConstructorArg is necessary here to separate the two variadic lists. 1496 template <typename... Ts, typename ConstructorArg, 1497 typename... ConstructorArgs, 1498 typename = std::enable_if_t<sizeof...(Ts) != 0>> 1499 RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) { 1500 // The following expands a call to emplace_back for each of the pattern 1501 // types 'Ts'. This magic is necessary due to a limitation in the places 1502 // that a parameter pack can be expanded in c++11. 1503 // FIXME: In c++17 this can be simplified by using 'fold expressions'. 1504 (void)std::initializer_list<int>{ 1505 0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...}; 1506 return *this; 1507 } 1508 1509 /// Add an instance of each of the pattern types 'Ts'. Return a reference to 1510 /// `this` for chaining insertions. 1511 template <typename... Ts> 1512 RewritePatternSet &insert() { 1513 (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...}; 1514 return *this; 1515 } 1516 1517 /// Add the given native pattern to the pattern list. Return a reference to 1518 /// `this` for chaining insertions. 1519 RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) { 1520 nativePatterns.emplace_back(std::move(pattern)); 1521 return *this; 1522 } 1523 1524 /// Add the given PDL pattern to the pattern list. Return a reference to 1525 /// `this` for chaining insertions. 1526 RewritePatternSet &insert(PDLPatternModule &&pattern) { 1527 pdlPatterns.mergeIn(std::move(pattern)); 1528 return *this; 1529 } 1530 1531 // Add a matchAndRewrite style pattern represented as a C function pointer. 1532 template <typename OpType> 1533 RewritePatternSet & 1534 insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) { 1535 struct FnPattern final : public OpRewritePattern<OpType> { 1536 FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter), 1537 MLIRContext *context) 1538 : OpRewritePattern<OpType>(context), implFn(implFn) { 1539 this->setDebugName(llvm::getTypeName<FnPattern>()); 1540 } 1541 1542 LogicalResult matchAndRewrite(OpType op, 1543 PatternRewriter &rewriter) const override { 1544 return implFn(op, rewriter); 1545 } 1546 1547 private: 1548 LogicalResult (*implFn)(OpType, PatternRewriter &rewriter); 1549 }; 1550 add(std::make_unique<FnPattern>(std::move(implFn), getContext())); 1551 return *this; 1552 } 1553 1554 private: 1555 /// Add an instance of the pattern type 'T'. Return a reference to `this` for 1556 /// chaining insertions. 1557 template <typename T, typename... Args> 1558 std::enable_if_t<std::is_base_of<RewritePattern, T>::value> 1559 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { 1560 std::unique_ptr<T> pattern = 1561 RewritePattern::create<T>(std::forward<Args>(args)...); 1562 pattern->addDebugLabels(debugLabels); 1563 nativePatterns.emplace_back(std::move(pattern)); 1564 } 1565 template <typename T, typename... Args> 1566 std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value> 1567 addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) { 1568 // TODO: Add the provided labels to the PDL pattern when PDL supports 1569 // labels. 1570 pdlPatterns.mergeIn(T(std::forward<Args>(args)...)); 1571 } 1572 1573 MLIRContext *const context; 1574 NativePatternListT nativePatterns; 1575 PDLPatternModule pdlPatterns; 1576 }; 1577 1578 } // namespace mlir 1579 1580 #endif // MLIR_IR_PATTERNMATCH_H 1581