1 //===- PatternMatch.h - PatternMatcher classes -------==---------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_IR_PATTERNMATCH_H
10 #define MLIR_IR_PATTERNMATCH_H
11 
12 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "llvm/ADT/FunctionExtras.h"
15 #include "llvm/Support/TypeName.h"
16 
17 namespace mlir {
18 
19 class PatternRewriter;
20 
21 //===----------------------------------------------------------------------===//
22 // PatternBenefit class
23 //===----------------------------------------------------------------------===//
24 
/// This class represents the benefit of a pattern match in a unitless scheme
/// that ranges from 0 (very little benefit) to 65534. The most common unit to
/// use here is the "number of operations matched" by the pattern.
///
/// The value 65535 (the largest `unsigned short`) is reserved as a sentinel
/// representation that can be used for patterns that fail to match. A
/// default-constructed PatternBenefit holds this sentinel.
///
class PatternBenefit {
  enum { ImpossibleToMatchSentinel = 65535 };

public:
  /// Construct a benefit that is "impossible to match" (the sentinel value).
  PatternBenefit() = default;
  /// Construct a benefit with the given value. This constructor is
  /// intentionally implicit so that integers can be used wherever a
  /// PatternBenefit is expected (e.g. `PatternBenefit benefit = 1`). It is
  /// defined out-of-line.
  PatternBenefit(unsigned benefit);
  PatternBenefit(const PatternBenefit &) = default;
  PatternBenefit &operator=(const PatternBenefit &) = default;

  /// Return the sentinel benefit used for patterns that can never match.
  static PatternBenefit impossibleToMatch() { return PatternBenefit(); }
  /// Return true if this benefit is the "impossible to match" sentinel.
  bool isImpossibleToMatch() const { return *this == impossibleToMatch(); }

  /// If the corresponding pattern can match, return its benefit. If the
  /// corresponding pattern isImpossibleToMatch() then this aborts.
  unsigned short getBenefit() const;

  /// Comparisons order benefits by their raw representation. Because the
  /// sentinel is the maximum `unsigned short`, an impossible-to-match benefit
  /// compares greater than every valid benefit.
  bool operator==(const PatternBenefit &rhs) const {
    return representation == rhs.representation;
  }
  bool operator!=(const PatternBenefit &rhs) const { return !(*this == rhs); }
  bool operator<(const PatternBenefit &rhs) const {
    return representation < rhs.representation;
  }
  bool operator>(const PatternBenefit &rhs) const { return rhs < *this; }
  bool operator<=(const PatternBenefit &rhs) const { return !(*this > rhs); }
  bool operator>=(const PatternBenefit &rhs) const { return !(*this < rhs); }

private:
  /// The raw benefit value, or `ImpossibleToMatchSentinel`.
  unsigned short representation{ImpossibleToMatchSentinel};
};
62 
63 //===----------------------------------------------------------------------===//
64 // Pattern
65 //===----------------------------------------------------------------------===//
66 
67 /// This class contains all of the data related to a pattern, but does not
68 /// contain any methods or logic for the actual matching. This class is solely
69 /// used to interface with the metadata of a pattern, such as the benefit or
70 /// root operation.
71 class Pattern {
72   /// This enum represents the kind of value used to select the root operations
73   /// that match this pattern.
74   enum class RootKind {
75     /// The pattern root matches "any" operation.
76     Any,
77     /// The pattern root is matched using a concrete operation name.
78     OperationName,
79     /// The pattern root is matched using an interface ID.
80     InterfaceID,
81     /// The patter root is matched using a trait ID.
82     TraitID
83   };
84 
85 public:
86   /// Return a list of operations that may be generated when rewriting an
87   /// operation instance with this pattern.
getGeneratedOps()88   ArrayRef<OperationName> getGeneratedOps() const { return generatedOps; }
89 
90   /// Return the root node that this pattern matches. Patterns that can match
91   /// multiple root types return None.
getRootKind()92   Optional<OperationName> getRootKind() const {
93     if (rootKind == RootKind::OperationName)
94       return OperationName::getFromOpaquePointer(rootValue);
95     return llvm::None;
96   }
97 
98   /// Return the interface ID used to match the root operation of this pattern.
99   /// If the pattern does not use an interface ID for deciding the root match,
100   /// this returns None.
getRootInterfaceID()101   Optional<TypeID> getRootInterfaceID() const {
102     if (rootKind == RootKind::InterfaceID)
103       return TypeID::getFromOpaquePointer(rootValue);
104     return llvm::None;
105   }
106 
107   /// Return the trait ID used to match the root operation of this pattern.
108   /// If the pattern does not use a trait ID for deciding the root match, this
109   /// returns None.
getRootTraitID()110   Optional<TypeID> getRootTraitID() const {
111     if (rootKind == RootKind::TraitID)
112       return TypeID::getFromOpaquePointer(rootValue);
113     return llvm::None;
114   }
115 
116   /// Return the benefit (the inverse of "cost") of matching this pattern.  The
117   /// benefit of a Pattern is always static - rewrites that may have dynamic
118   /// benefit can be instantiated multiple times (different Pattern instances)
119   /// for each benefit that they may return, and be guarded by different match
120   /// condition predicates.
getBenefit()121   PatternBenefit getBenefit() const { return benefit; }
122 
123   /// Returns true if this pattern is known to result in recursive application,
124   /// i.e. this pattern may generate IR that also matches this pattern, but is
125   /// known to bound the recursion. This signals to a rewrite driver that it is
126   /// safe to apply this pattern recursively to generated IR.
hasBoundedRewriteRecursion()127   bool hasBoundedRewriteRecursion() const {
128     return contextAndHasBoundedRecursion.getInt();
129   }
130 
131   /// Return the MLIRContext used to create this pattern.
getContext()132   MLIRContext *getContext() const {
133     return contextAndHasBoundedRecursion.getPointer();
134   }
135 
136   /// Return a readable name for this pattern. This name should only be used for
137   /// debugging purposes, and may be empty.
getDebugName()138   StringRef getDebugName() const { return debugName; }
139 
140   /// Set the human readable debug name used for this pattern. This name will
141   /// only be used for debugging purposes.
setDebugName(StringRef name)142   void setDebugName(StringRef name) { debugName = name; }
143 
144   /// Return the set of debug labels attached to this pattern.
getDebugLabels()145   ArrayRef<StringRef> getDebugLabels() const { return debugLabels; }
146 
147   /// Add the provided debug labels to this pattern.
addDebugLabels(ArrayRef<StringRef> labels)148   void addDebugLabels(ArrayRef<StringRef> labels) {
149     debugLabels.append(labels.begin(), labels.end());
150   }
addDebugLabels(StringRef label)151   void addDebugLabels(StringRef label) { debugLabels.push_back(label); }
152 
153 protected:
154   /// This class acts as a special tag that makes the desire to match "any"
155   /// operation type explicit. This helps to avoid unnecessary usages of this
156   /// feature, and ensures that the user is making a conscious decision.
157   struct MatchAnyOpTypeTag {};
158   /// This class acts as a special tag that makes the desire to match any
159   /// operation that implements a given interface explicit. This helps to avoid
160   /// unnecessary usages of this feature, and ensures that the user is making a
161   /// conscious decision.
162   struct MatchInterfaceOpTypeTag {};
163   /// This class acts as a special tag that makes the desire to match any
164   /// operation that implements a given trait explicit. This helps to avoid
165   /// unnecessary usages of this feature, and ensures that the user is making a
166   /// conscious decision.
167   struct MatchTraitOpTypeTag {};
168 
169   /// Construct a pattern with a certain benefit that matches the operation
170   /// with the given root name.
171   Pattern(StringRef rootName, PatternBenefit benefit, MLIRContext *context,
172           ArrayRef<StringRef> generatedNames = {});
173   /// Construct a pattern that may match any operation type. `generatedNames`
174   /// contains the names of operations that may be generated during a successful
175   /// rewrite. `MatchAnyOpTypeTag` is just a tag to ensure that the "match any"
176   /// behavior is what the user actually desired, `MatchAnyOpTypeTag()` should
177   /// always be supplied here.
178   Pattern(MatchAnyOpTypeTag tag, PatternBenefit benefit, MLIRContext *context,
179           ArrayRef<StringRef> generatedNames = {});
180   /// Construct a pattern that may match any operation that implements the
181   /// interface defined by the provided `interfaceID`. `generatedNames` contains
182   /// the names of operations that may be generated during a successful rewrite.
183   /// `MatchInterfaceOpTypeTag` is just a tag to ensure that the "match
184   /// interface" behavior is what the user actually desired,
185   /// `MatchInterfaceOpTypeTag()` should always be supplied here.
186   Pattern(MatchInterfaceOpTypeTag tag, TypeID interfaceID,
187           PatternBenefit benefit, MLIRContext *context,
188           ArrayRef<StringRef> generatedNames = {});
189   /// Construct a pattern that may match any operation that implements the
190   /// trait defined by the provided `traitID`. `generatedNames` contains the
191   /// names of operations that may be generated during a successful rewrite.
192   /// `MatchTraitOpTypeTag` is just a tag to ensure that the "match trait"
193   /// behavior is what the user actually desired, `MatchTraitOpTypeTag()` should
194   /// always be supplied here.
195   Pattern(MatchTraitOpTypeTag tag, TypeID traitID, PatternBenefit benefit,
196           MLIRContext *context, ArrayRef<StringRef> generatedNames = {});
197 
198   /// Set the flag detailing if this pattern has bounded rewrite recursion or
199   /// not.
200   void setHasBoundedRewriteRecursion(bool hasBoundedRecursionArg = true) {
201     contextAndHasBoundedRecursion.setInt(hasBoundedRecursionArg);
202   }
203 
204 private:
205   Pattern(const void *rootValue, RootKind rootKind,
206           ArrayRef<StringRef> generatedNames, PatternBenefit benefit,
207           MLIRContext *context);
208 
209   /// The value used to match the root operation of the pattern.
210   const void *rootValue;
211   RootKind rootKind;
212 
213   /// The expected benefit of matching this pattern.
214   const PatternBenefit benefit;
215 
216   /// The context this pattern was created from, and a boolean flag indicating
217   /// whether this pattern has bounded recursion or not.
218   llvm::PointerIntPair<MLIRContext *, 1, bool> contextAndHasBoundedRecursion;
219 
220   /// A list of the potential operations that may be generated when rewriting
221   /// an op with this pattern.
222   SmallVector<OperationName, 2> generatedOps;
223 
224   /// A readable name for this pattern. May be empty.
225   StringRef debugName;
226 
227   /// The set of debug labels attached to this pattern.
228   SmallVector<StringRef, 0> debugLabels;
229 };
230 
231 //===----------------------------------------------------------------------===//
232 // RewritePattern
233 //===----------------------------------------------------------------------===//
234 
235 /// RewritePattern is the common base class for all DAG to DAG replacements.
236 /// There are two possible usages of this class:
237 ///   * Multi-step RewritePattern with "match" and "rewrite"
238 ///     - By overloading the "match" and "rewrite" functions, the user can
239 ///       separate the concerns of matching and rewriting.
240 ///   * Single-step RewritePattern with "matchAndRewrite"
241 ///     - By overloading the "matchAndRewrite" function, the user can perform
242 ///       the rewrite in the same call as the match.
243 ///
244 class RewritePattern : public Pattern {
245 public:
246   virtual ~RewritePattern() = default;
247 
248   /// Rewrite the IR rooted at the specified operation with the result of
249   /// this pattern, generating any new operations with the specified
250   /// builder.  If an unexpected error is encountered (an internal
251   /// compiler error), it is emitted through the normal MLIR diagnostic
252   /// hooks and the IR is left in a valid state.
253   virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
254 
255   /// Attempt to match against code rooted at the specified operation,
256   /// which is the same operation code as getRootKind().
257   virtual LogicalResult match(Operation *op) const;
258 
259   /// Attempt to match against code rooted at the specified operation,
260   /// which is the same operation code as getRootKind(). If successful, this
261   /// function will automatically perform the rewrite.
matchAndRewrite(Operation * op,PatternRewriter & rewriter)262   virtual LogicalResult matchAndRewrite(Operation *op,
263                                         PatternRewriter &rewriter) const {
264     if (succeeded(match(op))) {
265       rewrite(op, rewriter);
266       return success();
267     }
268     return failure();
269   }
270 
271   /// This method provides a convenient interface for creating and initializing
272   /// derived rewrite patterns of the given type `T`.
273   template <typename T, typename... Args>
create(Args &&...args)274   static std::unique_ptr<T> create(Args &&...args) {
275     std::unique_ptr<T> pattern =
276         std::make_unique<T>(std::forward<Args>(args)...);
277     initializePattern<T>(*pattern);
278 
279     // Set a default debug name if one wasn't provided.
280     if (pattern->getDebugName().empty())
281       pattern->setDebugName(llvm::getTypeName<T>());
282     return pattern;
283   }
284 
285 protected:
286   /// Inherit the base constructors from `Pattern`.
287   using Pattern::Pattern;
288 
289 private:
290   /// Trait to check if T provides a `getOperationName` method.
291   template <typename T, typename... Args>
292   using has_initialize = decltype(std::declval<T>().initialize());
293   template <typename T>
294   using detect_has_initialize = llvm::is_detected<has_initialize, T>;
295 
296   /// Initialize the derived pattern by calling its `initialize` method.
297   template <typename T>
298   static std::enable_if_t<detect_has_initialize<T>::value>
initializePattern(T & pattern)299   initializePattern(T &pattern) {
300     pattern.initialize();
301   }
302   /// Empty derived pattern initializer for patterns that do not have an
303   /// initialize method.
304   template <typename T>
305   static std::enable_if_t<!detect_has_initialize<T>::value>
initializePattern(T &)306   initializePattern(T &) {}
307 
308   /// An anchor for the virtual table.
309   virtual void anchor();
310 };
311 
312 namespace detail {
313 /// OpOrInterfaceRewritePatternBase is a wrapper around RewritePattern that
314 /// allows for matching and rewriting against an instance of a derived operation
315 /// class or Interface.
316 template <typename SourceOp>
317 struct OpOrInterfaceRewritePatternBase : public RewritePattern {
318   using RewritePattern::RewritePattern;
319 
320   /// Wrappers around the RewritePattern methods that pass the derived op type.
rewriteOpOrInterfaceRewritePatternBase321   void rewrite(Operation *op, PatternRewriter &rewriter) const final {
322     rewrite(cast<SourceOp>(op), rewriter);
323   }
matchOpOrInterfaceRewritePatternBase324   LogicalResult match(Operation *op) const final {
325     return match(cast<SourceOp>(op));
326   }
matchAndRewriteOpOrInterfaceRewritePatternBase327   LogicalResult matchAndRewrite(Operation *op,
328                                 PatternRewriter &rewriter) const final {
329     return matchAndRewrite(cast<SourceOp>(op), rewriter);
330   }
331 
332   /// Rewrite and Match methods that operate on the SourceOp type. These must be
333   /// overridden by the derived pattern class.
rewriteOpOrInterfaceRewritePatternBase334   virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
335     llvm_unreachable("must override rewrite or matchAndRewrite");
336   }
matchOpOrInterfaceRewritePatternBase337   virtual LogicalResult match(SourceOp op) const {
338     llvm_unreachable("must override match or matchAndRewrite");
339   }
matchAndRewriteOpOrInterfaceRewritePatternBase340   virtual LogicalResult matchAndRewrite(SourceOp op,
341                                         PatternRewriter &rewriter) const {
342     if (succeeded(match(op))) {
343       rewrite(op, rewriter);
344       return success();
345     }
346     return failure();
347   }
348 };
349 } // namespace detail
350 
351 /// OpRewritePattern is a wrapper around RewritePattern that allows for
352 /// matching and rewriting against an instance of a derived operation class as
353 /// opposed to a raw Operation.
354 template <typename SourceOp>
355 struct OpRewritePattern
356     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
357   /// Patterns must specify the root operation name they match against, and can
358   /// also specify the benefit of the pattern matching and a list of generated
359   /// ops.
360   OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1,
361                    ArrayRef<StringRef> generatedNames = {})
362       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
363             SourceOp::getOperationName(), benefit, context, generatedNames) {}
364 };
365 
366 /// OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for
367 /// matching and rewriting against an instance of an operation interface instead
368 /// of a raw Operation.
369 template <typename SourceOp>
370 struct OpInterfaceRewritePattern
371     : public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
372   OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
373       : detail::OpOrInterfaceRewritePatternBase<SourceOp>(
374             Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),
375             benefit, context) {}
376 };
377 
378 /// OpTraitRewritePattern is a wrapper around RewritePattern that allows for
379 /// matching and rewriting against instances of an operation that possess a
380 /// given trait.
381 template <template <typename> class TraitType>
382 class OpTraitRewritePattern : public RewritePattern {
383 public:
384   OpTraitRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
RewritePattern(Pattern::MatchTraitOpTypeTag (),TypeID::get<TraitType> (),benefit,context)385       : RewritePattern(Pattern::MatchTraitOpTypeTag(), TypeID::get<TraitType>(),
386                        benefit, context) {}
387 };
388 
389 //===----------------------------------------------------------------------===//
390 // RewriterBase
391 //===----------------------------------------------------------------------===//
392 
393 /// This class coordinates the application of a rewrite on a set of IR,
394 /// providing a way for clients to track mutations and create new operations.
395 /// This class serves as a common API for IR mutation between pattern rewrites
396 /// and non-pattern rewrites, and facilitates the development of shared
397 /// IR transformation utilities.
398 class RewriterBase : public OpBuilder, public OpBuilder::Listener {
399 public:
400   /// Move the blocks that belong to "region" before the given position in
401   /// another region "parent". The two regions must be different. The caller
402   /// is responsible for creating or updating the operation transferring flow
403   /// of control to the region and passing it the correct block arguments.
404   virtual void inlineRegionBefore(Region &region, Region &parent,
405                                   Region::iterator before);
406   void inlineRegionBefore(Region &region, Block *before);
407 
408   /// Clone the blocks that belong to "region" before the given position in
409   /// another region "parent". The two regions must be different. The caller is
410   /// responsible for creating or updating the operation transferring flow of
411   /// control to the region and passing it the correct block arguments.
412   virtual void cloneRegionBefore(Region &region, Region &parent,
413                                  Region::iterator before,
414                                  BlockAndValueMapping &mapping);
415   void cloneRegionBefore(Region &region, Region &parent,
416                          Region::iterator before);
417   void cloneRegionBefore(Region &region, Block *before);
418 
419   /// This method replaces the uses of the results of `op` with the values in
420   /// `newValues` when the provided `functor` returns true for a specific use.
421   /// The number of values in `newValues` is required to match the number of
422   /// results of `op`. `allUsesReplaced`, if non-null, is set to true if all of
423   /// the uses of `op` were replaced. Note that in some rewriters, the given
424   /// 'functor' may be stored beyond the lifetime of the rewrite being applied.
425   /// As such, the function should not capture by reference and instead use
426   /// value capture as necessary.
427   virtual void
428   replaceOpWithIf(Operation *op, ValueRange newValues, bool *allUsesReplaced,
429                   llvm::unique_function<bool(OpOperand &) const> functor);
replaceOpWithIf(Operation * op,ValueRange newValues,llvm::unique_function<bool (OpOperand &)const> functor)430   void replaceOpWithIf(Operation *op, ValueRange newValues,
431                        llvm::unique_function<bool(OpOperand &) const> functor) {
432     replaceOpWithIf(op, newValues, /*allUsesReplaced=*/nullptr,
433                     std::move(functor));
434   }
435 
436   /// This method replaces the uses of the results of `op` with the values in
437   /// `newValues` when a use is nested within the given `block`. The number of
438   /// values in `newValues` is required to match the number of results of `op`.
439   /// If all uses of this operation are replaced, the operation is erased.
440   void replaceOpWithinBlock(Operation *op, ValueRange newValues, Block *block,
441                             bool *allUsesReplaced = nullptr);
442 
443   /// This method replaces the results of the operation with the specified list
444   /// of values. The number of provided values must match the number of results
445   /// of the operation.
446   virtual void replaceOp(Operation *op, ValueRange newValues);
447 
448   /// Replaces the result op with a new op that is created without verification.
449   /// The result values of the two ops must be the same types.
450   template <typename OpTy, typename... Args>
replaceOpWithNewOp(Operation * op,Args &&...args)451   OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
452     auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
453     replaceOpWithResultsOfAnotherOp(op, newOp.getOperation());
454     return newOp;
455   }
456 
457   /// This method erases an operation that is known to have no uses.
458   virtual void eraseOp(Operation *op);
459 
460   /// This method erases all operations in a block.
461   virtual void eraseBlock(Block *block);
462 
463   /// Merge the operations of block 'source' into the end of block 'dest'.
464   /// 'source's predecessors must either be empty or only contain 'dest`.
465   /// 'argValues' is used to replace the block arguments of 'source' after
466   /// merging.
467   virtual void mergeBlocks(Block *source, Block *dest,
468                            ValueRange argValues = llvm::None);
469 
470   // Merge the operations of block 'source' before the operation 'op'. Source
471   // block should not have existing predecessors or successors.
472   void mergeBlockBefore(Block *source, Operation *op,
473                         ValueRange argValues = llvm::None);
474 
475   /// Split the operations starting at "before" (inclusive) out of the given
476   /// block into a new block, and return it.
477   virtual Block *splitBlock(Block *block, Block::iterator before);
478 
479   /// This method is used to notify the rewriter that an in-place operation
480   /// modification is about to happen. A call to this function *must* be
481   /// followed by a call to either `finalizeRootUpdate` or `cancelRootUpdate`.
482   /// This is a minor efficiency win (it avoids creating a new operation and
483   /// removing the old one) but also often allows simpler code in the client.
startRootUpdate(Operation * op)484   virtual void startRootUpdate(Operation *op) {}
485 
486   /// This method is used to signal the end of a root update on the given
487   /// operation. This can only be called on operations that were provided to a
488   /// call to `startRootUpdate`.
finalizeRootUpdate(Operation * op)489   virtual void finalizeRootUpdate(Operation *op) {}
490 
491   /// This method cancels a pending root update. This can only be called on
492   /// operations that were provided to a call to `startRootUpdate`.
cancelRootUpdate(Operation * op)493   virtual void cancelRootUpdate(Operation *op) {}
494 
495   /// This method is a utility wrapper around a root update of an operation. It
496   /// wraps calls to `startRootUpdate` and `finalizeRootUpdate` around the given
497   /// callable.
498   template <typename CallableT>
updateRootInPlace(Operation * root,CallableT && callable)499   void updateRootInPlace(Operation *root, CallableT &&callable) {
500     startRootUpdate(root);
501     callable();
502     finalizeRootUpdate(root);
503   }
504 
505   /// Used to notify the rewriter that the IR failed to be rewritten because of
506   /// a match failure, and provide a callback to populate a diagnostic with the
507   /// reason why the failure occurred. This method allows for derived rewriters
508   /// to optionally hook into the reason why a rewrite failed, and display it to
509   /// users.
510   template <typename CallbackT>
511   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Location loc,CallbackT && reasonCallback)512   notifyMatchFailure(Location loc, CallbackT &&reasonCallback) {
513 #ifndef NDEBUG
514     return notifyMatchFailure(loc,
515                               function_ref<void(Diagnostic &)>(reasonCallback));
516 #else
517     return failure();
518 #endif
519   }
520   template <typename CallbackT>
521   std::enable_if_t<!std::is_convertible<CallbackT, Twine>::value, LogicalResult>
notifyMatchFailure(Operation * op,CallbackT && reasonCallback)522   notifyMatchFailure(Operation *op, CallbackT &&reasonCallback) {
523     return notifyMatchFailure(op->getLoc(),
524                               function_ref<void(Diagnostic &)>(reasonCallback));
525   }
526   template <typename ArgT>
notifyMatchFailure(ArgT && arg,const Twine & msg)527   LogicalResult notifyMatchFailure(ArgT &&arg, const Twine &msg) {
528     return notifyMatchFailure(std::forward<ArgT>(arg),
529                               [&](Diagnostic &diag) { diag << msg; });
530   }
531   template <typename ArgT>
notifyMatchFailure(ArgT && arg,const char * msg)532   LogicalResult notifyMatchFailure(ArgT &&arg, const char *msg) {
533     return notifyMatchFailure(std::forward<ArgT>(arg), Twine(msg));
534   }
535 
536 protected:
537   /// Initialize the builder with this rewriter as the listener.
RewriterBase(MLIRContext * ctx)538   explicit RewriterBase(MLIRContext *ctx) : OpBuilder(ctx, /*listener=*/this) {}
RewriterBase(const OpBuilder & otherBuilder)539   explicit RewriterBase(const OpBuilder &otherBuilder)
540       : OpBuilder(otherBuilder) {
541     setListener(this);
542   }
543   ~RewriterBase() override;
544 
545   /// These are the callback methods that subclasses can choose to implement if
546   /// they would like to be notified about certain types of mutations.
547 
548   /// Notify the rewriter that the specified operation is about to be replaced
549   /// with another set of operations. This is called before the uses of the
550   /// operation have been changed.
notifyRootReplaced(Operation * op)551   virtual void notifyRootReplaced(Operation *op) {}
552 
553   /// This is called on an operation that a rewrite is removing, right before
554   /// the operation is deleted. At this point, the operation has zero uses.
notifyOperationRemoved(Operation * op)555   virtual void notifyOperationRemoved(Operation *op) {}
556 
557   /// Notify the rewriter that the pattern failed to match the given operation,
558   /// and provide a callback to populate a diagnostic with the reason why the
559   /// failure occurred. This method allows for derived rewriters to optionally
560   /// hook into the reason why a rewrite failed, and display it to users.
561   virtual LogicalResult
notifyMatchFailure(Location loc,function_ref<void (Diagnostic &)> reasonCallback)562   notifyMatchFailure(Location loc,
563                      function_ref<void(Diagnostic &)> reasonCallback) {
564     return failure();
565   }
566 
567 private:
568   void operator=(const RewriterBase &) = delete;
569   RewriterBase(const RewriterBase &) = delete;
570 
571   /// 'op' and 'newOp' are known to have the same number of results, replace the
572   /// uses of op with uses of newOp.
573   void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp);
574 };
575 
576 //===----------------------------------------------------------------------===//
577 // IRRewriter
578 //===----------------------------------------------------------------------===//
579 
580 /// This class coordinates rewriting a piece of IR outside of a pattern rewrite,
581 /// providing a way to keep track of the mutations made to the IR. This class
582 /// should only be used in situations where another `RewriterBase` instance,
583 /// such as a `PatternRewriter`, is not available.
584 class IRRewriter : public RewriterBase {
585 public:
IRRewriter(MLIRContext * ctx)586   explicit IRRewriter(MLIRContext *ctx) : RewriterBase(ctx) {}
IRRewriter(const OpBuilder & builder)587   explicit IRRewriter(const OpBuilder &builder) : RewriterBase(builder) {}
588 };
589 
590 //===----------------------------------------------------------------------===//
591 // PatternRewriter
592 //===----------------------------------------------------------------------===//
593 
/// A special type of `RewriterBase` that coordinates the application of a
/// rewrite pattern on the current IR being matched, providing a way to keep
/// track of any mutations made. This class should be used to perform all
/// necessary IR mutations within a rewrite pattern, as the pattern driver may
/// be tracking various state that would be invalidated when a mutation takes
/// place.
///
/// PatternRewriter adds no members of its own. Its constructors are inherited
/// from `RewriterBase`. Those constructors are protected there, and the
/// `public:` access of the using-declaration does not widen the access of
/// inherited constructors. As a result, a PatternRewriter can only be created
/// through a derived class (presumably one supplied by a pattern-application
/// driver).
class PatternRewriter : public RewriterBase {
public:
  // Inherit RewriterBase's (protected, explicit) constructors unchanged.
  using RewriterBase::RewriterBase;
};
604 
605 //===----------------------------------------------------------------------===//
606 // PDLPatternModule
607 //===----------------------------------------------------------------------===//
608 
609 //===----------------------------------------------------------------------===//
610 // PDLValue
611 
612 /// Storage type of byte-code interpreter values. These are passed to constraint
613 /// functions as arguments.
614 class PDLValue {
615 public:
616   /// The underlying kind of a PDL value.
617   enum class Kind { Attribute, Operation, Type, TypeRange, Value, ValueRange };
618 
619   /// Construct a new PDL value.
620   PDLValue(const PDLValue &other) = default;
621   PDLValue(std::nullptr_t = nullptr) {}
PDLValue(Attribute value)622   PDLValue(Attribute value)
623       : value(value.getAsOpaquePointer()), kind(Kind::Attribute) {}
PDLValue(Operation * value)624   PDLValue(Operation *value) : value(value), kind(Kind::Operation) {}
PDLValue(Type value)625   PDLValue(Type value) : value(value.getAsOpaquePointer()), kind(Kind::Type) {}
PDLValue(TypeRange * value)626   PDLValue(TypeRange *value) : value(value), kind(Kind::TypeRange) {}
PDLValue(Value value)627   PDLValue(Value value)
628       : value(value.getAsOpaquePointer()), kind(Kind::Value) {}
PDLValue(ValueRange * value)629   PDLValue(ValueRange *value) : value(value), kind(Kind::ValueRange) {}
630 
631   /// Returns true if the type of the held value is `T`.
632   template <typename T>
isa()633   bool isa() const {
634     assert(value && "isa<> used on a null value");
635     return kind == getKindOf<T>();
636   }
637 
638   /// Attempt to dynamically cast this value to type `T`, returns null if this
639   /// value is not an instance of `T`.
640   template <typename T,
641             typename ResultT = std::conditional_t<
642                 std::is_convertible<T, bool>::value, T, Optional<T>>>
dyn_cast()643   ResultT dyn_cast() const {
644     return isa<T>() ? castImpl<T>() : ResultT();
645   }
646 
647   /// Cast this value to type `T`, asserts if this value is not an instance of
648   /// `T`.
649   template <typename T>
cast()650   T cast() const {
651     assert(isa<T>() && "expected value to be of type `T`");
652     return castImpl<T>();
653   }
654 
655   /// Get an opaque pointer to the value.
getAsOpaquePointer()656   const void *getAsOpaquePointer() const { return value; }
657 
658   /// Return if this value is null or not.
659   explicit operator bool() const { return value; }
660 
661   /// Return the kind of this value.
getKind()662   Kind getKind() const { return kind; }
663 
664   /// Print this value to the provided output stream.
665   void print(raw_ostream &os) const;
666 
667   /// Print the specified value kind to an output stream.
668   static void print(raw_ostream &os, Kind kind);
669 
670 private:
671   /// Find the index of a given type in a range of other types.
672   template <typename...>
673   struct index_of_t;
674   template <typename T, typename... R>
675   struct index_of_t<T, T, R...> : std::integral_constant<size_t, 0> {};
676   template <typename T, typename F, typename... R>
677   struct index_of_t<T, F, R...>
678       : std::integral_constant<size_t, 1 + index_of_t<T, R...>::value> {};
679 
680   /// Return the kind used for the given T.
681   template <typename T>
682   static Kind getKindOf() {
683     return static_cast<Kind>(index_of_t<T, Attribute, Operation *, Type,
684                                         TypeRange, Value, ValueRange>::value);
685   }
686 
687   /// The internal implementation of `cast`, that returns the underlying value
688   /// as the given type `T`.
689   template <typename T>
690   std::enable_if_t<llvm::is_one_of<T, Attribute, Type, Value>::value, T>
691   castImpl() const {
692     return T::getFromOpaquePointer(value);
693   }
694   template <typename T>
695   std::enable_if_t<llvm::is_one_of<T, TypeRange, ValueRange>::value, T>
696   castImpl() const {
697     return *reinterpret_cast<T *>(const_cast<void *>(value));
698   }
699   template <typename T>
700   std::enable_if_t<std::is_pointer<T>::value, T> castImpl() const {
701     return reinterpret_cast<T>(const_cast<void *>(value));
702   }
703 
704   /// The internal opaque representation of a PDLValue.
705   const void *value{nullptr};
706   /// The kind of the opaque value.
707   Kind kind{Kind::Attribute};
708 };
709 
710 inline raw_ostream &operator<<(raw_ostream &os, PDLValue value) {
711   value.print(os);
712   return os;
713 }
714 
715 inline raw_ostream &operator<<(raw_ostream &os, PDLValue::Kind kind) {
716   PDLValue::print(os, kind);
717   return os;
718 }
719 
720 //===----------------------------------------------------------------------===//
721 // PDLResultList
722 
723 /// The class represents a list of PDL results, returned by a native rewrite
724 /// method. It provides the mechanism with which to pass PDLValues back to the
725 /// PDL bytecode.
726 class PDLResultList {
727 public:
728   /// Push a new Attribute value onto the result list.
729   void push_back(Attribute value) { results.push_back(value); }
730 
731   /// Push a new Operation onto the result list.
732   void push_back(Operation *value) { results.push_back(value); }
733 
734   /// Push a new Type onto the result list.
735   void push_back(Type value) { results.push_back(value); }
736 
737   /// Push a new TypeRange onto the result list.
738   void push_back(TypeRange value) {
739     // The lifetime of a TypeRange can't be guaranteed, so we'll need to
740     // allocate a storage for it.
741     llvm::OwningArrayRef<Type> storage(value.size());
742     llvm::copy(value, storage.begin());
743     allocatedTypeRanges.emplace_back(std::move(storage));
744     typeRanges.push_back(allocatedTypeRanges.back());
745     results.push_back(&typeRanges.back());
746   }
747   void push_back(ValueTypeRange<OperandRange> value) {
748     typeRanges.push_back(value);
749     results.push_back(&typeRanges.back());
750   }
751   void push_back(ValueTypeRange<ResultRange> value) {
752     typeRanges.push_back(value);
753     results.push_back(&typeRanges.back());
754   }
755 
756   /// Push a new Value onto the result list.
757   void push_back(Value value) { results.push_back(value); }
758 
759   /// Push a new ValueRange onto the result list.
760   void push_back(ValueRange value) {
761     // The lifetime of a ValueRange can't be guaranteed, so we'll need to
762     // allocate a storage for it.
763     llvm::OwningArrayRef<Value> storage(value.size());
764     llvm::copy(value, storage.begin());
765     allocatedValueRanges.emplace_back(std::move(storage));
766     valueRanges.push_back(allocatedValueRanges.back());
767     results.push_back(&valueRanges.back());
768   }
769   void push_back(OperandRange value) {
770     valueRanges.push_back(value);
771     results.push_back(&valueRanges.back());
772   }
773   void push_back(ResultRange value) {
774     valueRanges.push_back(value);
775     results.push_back(&valueRanges.back());
776   }
777 
778 protected:
779   /// Create a new result list with the expected number of results.
780   PDLResultList(unsigned maxNumResults) {
781     // For now just reserve enough space for all of the results. We could do
782     // separate counts per range type, but it isn't really worth it unless there
783     // are a "large" number of results.
784     typeRanges.reserve(maxNumResults);
785     valueRanges.reserve(maxNumResults);
786   }
787 
788   /// The PDL results held by this list.
789   SmallVector<PDLValue> results;
790   /// Memory used to store ranges held by the list.
791   SmallVector<TypeRange> typeRanges;
792   SmallVector<ValueRange> valueRanges;
793   /// Memory allocated to store ranges in the result list whose lifetime was
794   /// generated in the native function.
795   SmallVector<llvm::OwningArrayRef<Type>> allocatedTypeRanges;
796   SmallVector<llvm::OwningArrayRef<Value>> allocatedValueRanges;
797 };
798 
799 //===----------------------------------------------------------------------===//
800 // PDLPatternModule
801 
/// A generic PDL pattern constraint function. This function applies a
/// constraint to a given set of opaque PDLValue entities. Returns success if
/// the constraint successfully held, failure otherwise. Functions of this form
/// are registered via `PDLPatternModule::registerConstraintFunction`.
using PDLConstraintFunction =
    std::function<LogicalResult(PatternRewriter &, ArrayRef<PDLValue>)>;
/// A native PDL rewrite function. This function performs a rewrite on the
/// given set of values. Any results from this rewrite that should be passed
/// back to PDL should be added to the provided result list. This method is only
/// invoked when the corresponding match was successful. Functions of this form
/// are registered via `PDLPatternModule::registerRewriteFunction`.
using PDLRewriteFunction =
    std::function<void(PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)>;
813 
814 namespace detail {
815 namespace pdl_function_builder {
/// A compile-time constant that is always `false`, yet whose value formally
/// depends on its template arguments. Because of that dependence, a
/// `static_assert` on it is only evaluated when the surrounding template is
/// actually instantiated, which makes it suitable for "this must never be
/// used" diagnostics. For example:
///
/// template <typename T>
/// void foo() {
///  static_assert(always_false<T>, "This function should never be called");
/// }
///
template <typename... Ts>
constexpr bool always_false = false;
828 
829 //===----------------------------------------------------------------------===//
830 // PDL Function Builder: Type Processing
831 //===----------------------------------------------------------------------===//
832 
/// This struct provides a convenient way to determine how to process a given
/// type as either a PDL parameter, or a result value. This allows for
/// supporting complex types in constraint and rewrite functions, without
/// requiring the user to hand-write the necessary glue code themselves.
/// Specializations of this class should implement the following methods to
/// enable support as a PDL argument or result type:
///
///   static LogicalResult verifyAsArg(
///     function_ref<LogicalResult(const Twine &)> errorFn, PDLValue pdlValue,
///     size_t argIdx);
///
///     * This method verifies that the given PDLValue is valid for use as a
///       value of `T`.
///
///   static T processAsArg(PDLValue pdlValue);
///
///     * This method processes the given PDLValue as a value of `T`.
///
///   static void processAsResult(PatternRewriter &, PDLResultList &results,
///                               const T &value);
///
///     * This method processes the given value of `T` as the result of a
///       function invocation. The method should package the value into an
///       appropriate form and append it to the given result list.
///
/// A type that is only ever produced as a result may provide just
/// `processAsResult` (see, e.g., the `OperandRange` specialization).
///
/// If the type `T` is based on a higher order value, consider using
/// `ProcessPDLValueBasedOn` as a base class of the specialization to simplify
/// the implementation.
///
/// The `Enable` parameter lets partial specializations be constrained with
/// `std::enable_if_t`, e.g. to cover every class deriving from `Attribute`.
///
template <typename T, typename Enable = void>
struct ProcessPDLValue;
864 
865 /// This struct provides a simplified model for processing types that are based
866 /// on another type, e.g. APInt is based on the handling for IntegerAttr. This
867 /// allows for building the necessary processing functions on top of the base
868 /// value instead of a PDLValue. Derived users should implement the following
869 /// (which subsume the ProcessPDLValue variants):
870 ///
871 ///   static LogicalResult verifyAsArg(
872 ///     function_ref<LogicalResult(const Twine &)> errorFn,
873 ///     const BaseT &baseValue, size_t argIdx);
874 ///
875 ///     * This method verifies that the given PDLValue is valid for use as a
876 ///       value of `T`.
877 ///
878 ///   static T processAsArg(BaseT baseValue);
879 ///
880 ///     *  This method processes the given base value as a value of `T`.
881 ///
882 template <typename T, typename BaseT>
883 struct ProcessPDLValueBasedOn {
884   static LogicalResult
885   verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
886               PDLValue pdlValue, size_t argIdx) {
887     // Verify the base class before continuing.
888     if (failed(ProcessPDLValue<BaseT>::verifyAsArg(errorFn, pdlValue, argIdx)))
889       return failure();
890     return ProcessPDLValue<T>::verifyAsArg(
891         errorFn, ProcessPDLValue<BaseT>::processAsArg(pdlValue), argIdx);
892   }
893   static T processAsArg(PDLValue pdlValue) {
894     return ProcessPDLValue<T>::processAsArg(
895         ProcessPDLValue<BaseT>::processAsArg(pdlValue));
896   }
897 
898   /// Explicitly add the expected parent API to ensure the parent class
899   /// implements the necessary API (and doesn't implicitly inherit it from
900   /// somewhere else).
901   static LogicalResult
902   verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn, BaseT value,
903               size_t argIdx) {
904     return success();
905   }
906   static T processAsArg(BaseT baseValue);
907 };
908 
909 /// This struct provides a simplified model for processing types that have
910 /// "builtin" PDLValue support:
911 ///   * Attribute, Operation *, Type, TypeRange, ValueRange
912 template <typename T>
913 struct ProcessBuiltinPDLValue {
914   static LogicalResult
915   verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
916               PDLValue pdlValue, size_t argIdx) {
917     if (pdlValue)
918       return success();
919     return errorFn("expected a non-null value for argument " + Twine(argIdx) +
920                    " of type: " + llvm::getTypeName<T>());
921   }
922 
923   static T processAsArg(PDLValue pdlValue) { return pdlValue.cast<T>(); }
924   static void processAsResult(PatternRewriter &, PDLResultList &results,
925                               T value) {
926     results.push_back(value);
927   }
928 };
929 
930 /// This struct provides a simplified model for processing types that inherit
931 /// from builtin PDLValue types. For example, derived attributes like
932 /// IntegerAttr, derived types like IntegerType, derived operations like
933 /// ModuleOp, Interfaces, etc.
934 template <typename T, typename BaseT>
935 struct ProcessDerivedPDLValue : public ProcessPDLValueBasedOn<T, BaseT> {
936   static LogicalResult
937   verifyAsArg(function_ref<LogicalResult(const Twine &)> errorFn,
938               BaseT baseValue, size_t argIdx) {
939     return TypeSwitch<BaseT, LogicalResult>(baseValue)
940         .Case([&](T) { return success(); })
941         .Default([&](BaseT) {
942           return errorFn("expected argument " + Twine(argIdx) +
943                          " to be of type: " + llvm::getTypeName<T>());
944         });
945   }
946   using ProcessPDLValueBasedOn<T, BaseT>::verifyAsArg;
947 
948   static T processAsArg(BaseT baseValue) {
949     return baseValue.template cast<T>();
950   }
951   using ProcessPDLValueBasedOn<T, BaseT>::processAsArg;
952 
953   static void processAsResult(PatternRewriter &, PDLResultList &results,
954                               T value) {
955     results.push_back(value);
956   }
957 };
958 
959 //===----------------------------------------------------------------------===//
960 // Attribute
961 
/// `Attribute` is handled as a builtin PDL value.
template <>
struct ProcessPDLValue<Attribute> : public ProcessBuiltinPDLValue<Attribute> {};
/// Any class deriving from `Attribute` (e.g. IntegerAttr) is unwrapped as an
/// `Attribute`, verified to be an instance of `T`, and then cast to `T`.
template <typename T>
struct ProcessPDLValue<T,
                       std::enable_if_t<std::is_base_of<Attribute, T>::value>>
    : public ProcessDerivedPDLValue<T, Attribute> {};
968 
969 /// Handling for various Attribute value types.
970 template <>
971 struct ProcessPDLValue<StringRef>
972     : public ProcessPDLValueBasedOn<StringRef, StringAttr> {
973   static StringRef processAsArg(StringAttr value) { return value.getValue(); }
974   using ProcessPDLValueBasedOn<StringRef, StringAttr>::processAsArg;
975 
976   static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
977                               StringRef value) {
978     results.push_back(rewriter.getStringAttr(value));
979   }
980 };
981 template <>
982 struct ProcessPDLValue<std::string>
983     : public ProcessPDLValueBasedOn<std::string, StringAttr> {
984   template <typename T>
985   static std::string processAsArg(T value) {
986     static_assert(always_false<T>,
987                   "`std::string` arguments require a string copy, use "
988                   "`StringRef` for string-like arguments instead");
989     return {};
990   }
991   static void processAsResult(PatternRewriter &rewriter, PDLResultList &results,
992                               StringRef value) {
993     results.push_back(rewriter.getStringAttr(value));
994   }
995 };
996 
997 //===----------------------------------------------------------------------===//
998 // Operation
999 
1000 template <>
1001 struct ProcessPDLValue<Operation *>
1002     : public ProcessBuiltinPDLValue<Operation *> {};
1003 template <typename T>
1004 struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<OpState, T>::value>>
1005     : public ProcessDerivedPDLValue<T, Operation *> {
1006   static T processAsArg(Operation *value) { return cast<T>(value); }
1007 };
1008 
1009 //===----------------------------------------------------------------------===//
1010 // Type
1011 
/// `Type` is handled as a builtin PDL value.
template <>
struct ProcessPDLValue<Type> : public ProcessBuiltinPDLValue<Type> {};
/// Any class deriving from `Type` (e.g. IntegerType) is unwrapped as a `Type`,
/// verified to be an instance of `T`, and then cast to `T`.
template <typename T>
struct ProcessPDLValue<T, std::enable_if_t<std::is_base_of<Type, T>::value>>
    : public ProcessDerivedPDLValue<T, Type> {};
1017 
1018 //===----------------------------------------------------------------------===//
1019 // TypeRange
1020 
/// `TypeRange` is handled as a builtin PDL value, so it can be used both as an
/// argument and as a result.
template <>
struct ProcessPDLValue<TypeRange> : public ProcessBuiltinPDLValue<TypeRange> {};
/// The type ranges of operand and result ranges are only supported as results.
/// Unlike a plain `TypeRange` result, they are appended to the result list
/// without first copying the types into separately allocated storage.
template <>
struct ProcessPDLValue<ValueTypeRange<OperandRange>> {
  static void processAsResult(PatternRewriter &, PDLResultList &results,
                              ValueTypeRange<OperandRange> types) {
    results.push_back(types);
  }
};
template <>
struct ProcessPDLValue<ValueTypeRange<ResultRange>> {
  static void processAsResult(PatternRewriter &, PDLResultList &results,
                              ValueTypeRange<ResultRange> types) {
    results.push_back(types);
  }
};
1037 
1038 //===----------------------------------------------------------------------===//
1039 // Value
1040 
/// `Value` is handled as a builtin PDL value, usable as an argument or result.
template <>
struct ProcessPDLValue<Value> : public ProcessBuiltinPDLValue<Value> {};
1043 
1044 //===----------------------------------------------------------------------===//
1045 // ValueRange
1046 
/// `ValueRange` is handled as a builtin PDL value, so it can be used both as
/// an argument and as a result.
template <>
struct ProcessPDLValue<ValueRange> : public ProcessBuiltinPDLValue<ValueRange> {
};
/// Operand and result ranges are only supported as results. Unlike a plain
/// `ValueRange` result, they are appended to the result list without first
/// copying the values into separately allocated storage.
template <>
struct ProcessPDLValue<OperandRange> {
  static void processAsResult(PatternRewriter &, PDLResultList &results,
                              OperandRange values) {
    results.push_back(values);
  }
};
template <>
struct ProcessPDLValue<ResultRange> {
  static void processAsResult(PatternRewriter &, PDLResultList &results,
                              ResultRange values) {
    results.push_back(values);
  }
};
1064 
1065 //===----------------------------------------------------------------------===//
1066 // PDL Function Builder: Argument Handling
1067 //===----------------------------------------------------------------------===//
1068 
1069 /// Validate the given PDLValues match the constraints defined by the argument
1070 /// types of the given function. In the case of failure, a match failure
1071 /// diagnostic is emitted.
1072 /// FIXME: This should be completely removed in favor of `assertArgs`, but PDL
1073 /// does not currently preserve Constraint application ordering.
1074 template <typename PDLFnT, std::size_t... I>
1075 LogicalResult verifyAsArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
1076                            std::index_sequence<I...>) {
1077   using FnTraitsT = llvm::function_traits<PDLFnT>;
1078 
1079   auto errorFn = [&](const Twine &msg) {
1080     return rewriter.notifyMatchFailure(rewriter.getUnknownLoc(), msg);
1081   };
1082   LogicalResult result = success();
1083   (void)std::initializer_list<int>{
1084       (result =
1085            succeeded(result)
1086                ? ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1087                      verifyAsArg(errorFn, values[I], I)
1088                : failure(),
1089        0)...};
1090   return result;
1091 }
1092 
1093 /// Assert that the given PDLValues match the constraints defined by the
1094 /// arguments of the given function. In the case of failure, a fatal error
1095 /// is emitted.
1096 template <typename PDLFnT, std::size_t... I>
1097 void assertArgs(PatternRewriter &rewriter, ArrayRef<PDLValue> values,
1098                 std::index_sequence<I...>) {
1099   // We only want to do verification in debug builds, same as with `assert`.
1100 #if LLVM_ENABLE_ABI_BREAKING_CHECKS
1101   using FnTraitsT = llvm::function_traits<PDLFnT>;
1102   auto errorFn = [&](const Twine &msg) -> LogicalResult {
1103     llvm::report_fatal_error(msg);
1104   };
1105   (void)std::initializer_list<int>{
1106       (assert(succeeded(ProcessPDLValue<typename FnTraitsT::template arg_t<
1107                             I + 1>>::verifyAsArg(errorFn, values[I], I))),
1108        0)...};
1109 #endif
1110 }
1111 
1112 //===----------------------------------------------------------------------===//
1113 // PDL Function Builder: Results Handling
1114 //===----------------------------------------------------------------------===//
1115 
1116 /// Store a single result within the result list.
1117 template <typename T>
1118 static void processResults(PatternRewriter &rewriter, PDLResultList &results,
1119                            T &&value) {
1120   ProcessPDLValue<T>::processAsResult(rewriter, results,
1121                                       std::forward<T>(value));
1122 }
1123 
1124 /// Store a std::pair<> as individual results within the result list.
1125 template <typename T1, typename T2>
1126 static void processResults(PatternRewriter &rewriter, PDLResultList &results,
1127                            std::pair<T1, T2> &&pair) {
1128   processResults(rewriter, results, std::move(pair.first));
1129   processResults(rewriter, results, std::move(pair.second));
1130 }
1131 
1132 /// Store a std::tuple<> as individual results within the result list.
1133 template <typename... Ts>
1134 static void processResults(PatternRewriter &rewriter, PDLResultList &results,
1135                            std::tuple<Ts...> &&tuple) {
1136   auto applyFn = [&](auto &&...args) {
1137     // TODO: Use proper fold expressions when we have C++17. For now we use a
1138     // bogus std::initializer_list to work around C++14 limitations.
1139     (void)std::initializer_list<int>{
1140         (processResults(rewriter, results, std::move(args)), 0)...};
1141   };
1142   llvm::apply_tuple(applyFn, std::move(tuple));
1143 }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // PDL Constraint Builder
1147 //===----------------------------------------------------------------------===//
1148 
1149 /// Process the arguments of a native constraint and invoke it.
1150 template <typename PDLFnT, std::size_t... I,
1151           typename FnTraitsT = llvm::function_traits<PDLFnT>>
1152 typename FnTraitsT::result_t
1153 processArgsAndInvokeConstraint(PDLFnT &fn, PatternRewriter &rewriter,
1154                                ArrayRef<PDLValue> values,
1155                                std::index_sequence<I...>) {
1156   return fn(
1157       rewriter,
1158       (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1159           values[I]))...);
1160 }
1161 
1162 /// Build a constraint function from the given function `ConstraintFnT`. This
1163 /// allows for enabling the user to define simpler, more direct constraint
1164 /// functions without needing to handle the low-level PDL goop.
1165 ///
1166 /// If the constraint function is already in the correct form, we just forward
1167 /// it directly.
1168 template <typename ConstraintFnT>
1169 std::enable_if_t<
1170     std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1171     PDLConstraintFunction>
1172 buildConstraintFn(ConstraintFnT &&constraintFn) {
1173   return std::forward<ConstraintFnT>(constraintFn);
1174 }
1175 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1176 /// we desire.
1177 template <typename ConstraintFnT>
1178 std::enable_if_t<
1179     !std::is_convertible<ConstraintFnT, PDLConstraintFunction>::value,
1180     PDLConstraintFunction>
1181 buildConstraintFn(ConstraintFnT &&constraintFn) {
1182   return [constraintFn = std::forward<ConstraintFnT>(constraintFn)](
1183              PatternRewriter &rewriter,
1184              ArrayRef<PDLValue> values) -> LogicalResult {
1185     auto argIndices = std::make_index_sequence<
1186         llvm::function_traits<ConstraintFnT>::num_args - 1>();
1187     if (failed(verifyAsArgs<ConstraintFnT>(rewriter, values, argIndices)))
1188       return failure();
1189     return processArgsAndInvokeConstraint(constraintFn, rewriter, values,
1190                                           argIndices);
1191   };
1192 }
1193 
1194 //===----------------------------------------------------------------------===//
1195 // PDL Rewrite Builder
1196 //===----------------------------------------------------------------------===//
1197 
1198 /// Process the arguments of a native rewrite and invoke it.
1199 /// This overload handles the case of no return values.
1200 template <typename PDLFnT, std::size_t... I,
1201           typename FnTraitsT = llvm::function_traits<PDLFnT>>
1202 std::enable_if_t<std::is_same<typename FnTraitsT::result_t, void>::value>
1203 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
1204                             PDLResultList &, ArrayRef<PDLValue> values,
1205                             std::index_sequence<I...>) {
1206   fn(rewriter,
1207      (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::processAsArg(
1208          values[I]))...);
1209 }
1210 /// This overload handles the case of return values, which need to be packaged
1211 /// into the result list.
1212 template <typename PDLFnT, std::size_t... I,
1213           typename FnTraitsT = llvm::function_traits<PDLFnT>>
1214 std::enable_if_t<!std::is_same<typename FnTraitsT::result_t, void>::value>
1215 processArgsAndInvokeRewrite(PDLFnT &fn, PatternRewriter &rewriter,
1216                             PDLResultList &results, ArrayRef<PDLValue> values,
1217                             std::index_sequence<I...>) {
1218   processResults(
1219       rewriter, results,
1220       fn(rewriter, (ProcessPDLValue<typename FnTraitsT::template arg_t<I + 1>>::
1221                         processAsArg(values[I]))...));
1222 }
1223 
1224 /// Build a rewrite function from the given function `RewriteFnT`. This
1225 /// allows for enabling the user to define simpler, more direct rewrite
1226 /// functions without needing to handle the low-level PDL goop.
1227 ///
1228 /// If the rewrite function is already in the correct form, we just forward
1229 /// it directly.
1230 template <typename RewriteFnT>
1231 std::enable_if_t<std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1232                  PDLRewriteFunction>
1233 buildRewriteFn(RewriteFnT &&rewriteFn) {
1234   return std::forward<RewriteFnT>(rewriteFn);
1235 }
1236 /// Otherwise, we generate a wrapper that will unpack the PDLValues in the form
1237 /// we desire.
1238 template <typename RewriteFnT>
1239 std::enable_if_t<!std::is_convertible<RewriteFnT, PDLRewriteFunction>::value,
1240                  PDLRewriteFunction>
1241 buildRewriteFn(RewriteFnT &&rewriteFn) {
1242   return [rewriteFn = std::forward<RewriteFnT>(rewriteFn)](
1243              PatternRewriter &rewriter, PDLResultList &results,
1244              ArrayRef<PDLValue> values) {
1245     auto argIndices =
1246         std::make_index_sequence<llvm::function_traits<RewriteFnT>::num_args -
1247                                  1>();
1248     assertArgs<RewriteFnT>(rewriter, values, argIndices);
1249     processArgsAndInvokeRewrite(rewriteFn, rewriter, results, values,
1250                                 argIndices);
1251   };
1252 }
1253 
1254 } // namespace pdl_function_builder
1255 } // namespace detail
1256 
1257 /// This class contains all of the necessary data for a set of PDL patterns, or
1258 /// pattern rewrites specified in the form of the PDL dialect. This PDL module
1259 /// contained by this pattern may contain any number of `pdl.pattern`
1260 /// operations.
1261 class PDLPatternModule {
1262 public:
1263   PDLPatternModule() = default;
1264 
1265   /// Construct a PDL pattern with the given module.
1266   PDLPatternModule(OwningOpRef<ModuleOp> pdlModule)
1267       : pdlModule(std::move(pdlModule)) {}
1268 
1269   /// Merge the state in `other` into this pattern module.
1270   void mergeIn(PDLPatternModule &&other);
1271 
1272   /// Return the internal PDL module of this pattern.
1273   ModuleOp getModule() { return pdlModule.get(); }
1274 
1275   //===--------------------------------------------------------------------===//
1276   // Function Registry
1277 
1278   /// Register a constraint function with PDL. A constraint function may be
1279   /// specified in one of two ways:
1280   ///
1281   ///   * `LogicalResult (PatternRewriter &, ArrayRef<PDLValue>)`
1282   ///
1283   ///   In this overload the arguments of the constraint function are passed via
1284   ///   the low-level PDLValue form.
1285   ///
1286   ///   * `LogicalResult (PatternRewriter &, ValueTs... values)`
1287   ///
1288   ///   In this form the arguments of the constraint function are passed via the
1289   ///   expected high level C++ type. In this form, the framework will
1290   ///   automatically unwrap PDLValues and convert them to the expected ValueTs.
1291   ///   For example, if the constraint function accepts a `Operation *`, the
1292   ///   framework will automatically cast the input PDLValue. In the case of a
1293   ///   `StringRef`, the framework will automatically unwrap the argument as a
1294   ///   StringAttr and pass the underlying string value. To see the full list of
1295   ///   supported types, or to see how to add handling for custom types, view
1296   ///   the definition of `ProcessPDLValue` above.
1297   void registerConstraintFunction(StringRef name,
1298                                   PDLConstraintFunction constraintFn);
1299   template <typename ConstraintFnT>
1300   void registerConstraintFunction(StringRef name,
1301                                   ConstraintFnT &&constraintFn) {
1302     registerConstraintFunction(name,
1303                                detail::pdl_function_builder::buildConstraintFn(
1304                                    std::forward<ConstraintFnT>(constraintFn)));
1305   }
1306 
1307   /// Register a rewrite function with PDL. A rewrite function may be specified
1308   /// in one of two ways:
1309   ///
1310   ///   * `void (PatternRewriter &, PDLResultList &, ArrayRef<PDLValue>)`
1311   ///
1312   ///   In this overload the arguments of the constraint function are passed via
1313   ///   the low-level PDLValue form, and the results are manually appended to
1314   ///   the given result list.
1315   ///
1316   ///   * `ResultT (PatternRewriter &, ValueTs... values)`
1317   ///
1318   ///   In this form the arguments and result of the rewrite function are passed
1319   ///   via the expected high level C++ type. In this form, the framework will
1320   ///   automatically unwrap the PDLValues arguments and convert them to the
1321   ///   expected ValueTs. It will also automatically handle the processing and
1322   ///   packaging of the result value to the result list. For example, if the
1323   ///   rewrite function takes a `Operation *`, the framework will automatically
1324   ///   cast the input PDLValue. In the case of a `StringRef`, the framework
1325   ///   will automatically unwrap the argument as a StringAttr and pass the
1326   ///   underlying string value. In the reverse case, if the rewrite returns a
1327   ///   StringRef or std::string, it will automatically package this as a
1328   ///   StringAttr and append it to the result list. To see the full list of
1329   ///   supported types, or to see how to add handling for custom types, view
1330   ///   the definition of `ProcessPDLValue` above.
1331   void registerRewriteFunction(StringRef name, PDLRewriteFunction rewriteFn);
1332   template <typename RewriteFnT>
1333   void registerRewriteFunction(StringRef name, RewriteFnT &&rewriteFn) {
1334     registerRewriteFunction(name, detail::pdl_function_builder::buildRewriteFn(
1335                                       std::forward<RewriteFnT>(rewriteFn)));
1336   }
1337 
1338   /// Return the set of the registered constraint functions.
1339   const llvm::StringMap<PDLConstraintFunction> &getConstraintFunctions() const {
1340     return constraintFunctions;
1341   }
1342   llvm::StringMap<PDLConstraintFunction> takeConstraintFunctions() {
1343     return constraintFunctions;
1344   }
1345   /// Return the set of the registered rewrite functions.
1346   const llvm::StringMap<PDLRewriteFunction> &getRewriteFunctions() const {
1347     return rewriteFunctions;
1348   }
1349   llvm::StringMap<PDLRewriteFunction> takeRewriteFunctions() {
1350     return rewriteFunctions;
1351   }
1352 
1353   /// Clear out the patterns and functions within this module.
1354   void clear() {
1355     pdlModule = nullptr;
1356     constraintFunctions.clear();
1357     rewriteFunctions.clear();
1358   }
1359 
private:
  /// The module containing the `pdl.pattern` operations. Held through an
  /// `OwningOpRef`; `clear()` resets it to null.
  OwningOpRef<ModuleOp> pdlModule;

  /// The external native functions referenced from within the PDL module:
  /// constraint functions and rewrite functions (the latter populated via
  /// `registerRewriteFunction`). NOTE(review): presumably keyed by the name
  /// passed at registration time — confirm against the out-of-line definitions.
  llvm::StringMap<PDLConstraintFunction> constraintFunctions;
  llvm::StringMap<PDLRewriteFunction> rewriteFunctions;
1367 };
1368 
1369 //===----------------------------------------------------------------------===//
1370 // RewritePatternSet
1371 //===----------------------------------------------------------------------===//
1372 
1373 class RewritePatternSet {
1374   using NativePatternListT = std::vector<std::unique_ptr<RewritePattern>>;
1375 
1376 public:
1377   RewritePatternSet(MLIRContext *context) : context(context) {}
1378 
1379   /// Construct a RewritePatternSet populated with the given pattern.
1380   RewritePatternSet(MLIRContext *context,
1381                     std::unique_ptr<RewritePattern> pattern)
1382       : context(context) {
1383     nativePatterns.emplace_back(std::move(pattern));
1384   }
1385   RewritePatternSet(PDLPatternModule &&pattern)
1386       : context(pattern.getModule()->getContext()),
1387         pdlPatterns(std::move(pattern)) {}
1388 
1389   MLIRContext *getContext() const { return context; }
1390 
1391   /// Return the native patterns held in this list.
1392   NativePatternListT &getNativePatterns() { return nativePatterns; }
1393 
1394   /// Return the PDL patterns held in this list.
1395   PDLPatternModule &getPDLPatterns() { return pdlPatterns; }
1396 
1397   /// Clear out all of the held patterns in this list.
1398   void clear() {
1399     nativePatterns.clear();
1400     pdlPatterns.clear();
1401   }
1402 
1403   //===--------------------------------------------------------------------===//
1404   // 'add' methods for adding patterns to the set.
1405   //===--------------------------------------------------------------------===//
1406 
1407   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1408   /// the given arguments. Return a reference to `this` for chaining insertions.
1409   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1410   template <typename... Ts, typename ConstructorArg,
1411             typename... ConstructorArgs,
1412             typename = std::enable_if_t<sizeof...(Ts) != 0>>
1413   RewritePatternSet &add(ConstructorArg &&arg, ConstructorArgs &&...args) {
1414     // The following expands a call to emplace_back for each of the pattern
1415     // types 'Ts'. This magic is necessary due to a limitation in the places
1416     // that a parameter pack can be expanded in c++11.
1417     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
1418     (void)std::initializer_list<int>{
1419         0, (addImpl<Ts>(/*debugLabels=*/llvm::None,
1420                         std::forward<ConstructorArg>(arg),
1421                         std::forward<ConstructorArgs>(args)...),
1422             0)...};
1423     return *this;
1424   }
1425   /// An overload of the above `add` method that allows for attaching a set
1426   /// of debug labels to the attached patterns. This is useful for labeling
1427   /// groups of patterns that may be shared between multiple different
1428   /// passes/users.
1429   template <typename... Ts, typename ConstructorArg,
1430             typename... ConstructorArgs,
1431             typename = std::enable_if_t<sizeof...(Ts) != 0>>
1432   RewritePatternSet &addWithLabel(ArrayRef<StringRef> debugLabels,
1433                                   ConstructorArg &&arg,
1434                                   ConstructorArgs &&...args) {
1435     // The following expands a call to emplace_back for each of the pattern
1436     // types 'Ts'. This magic is necessary due to a limitation in the places
1437     // that a parameter pack can be expanded in c++11.
1438     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
1439     (void)std::initializer_list<int>{
1440         0, (addImpl<Ts>(debugLabels, arg, args...), 0)...};
1441     return *this;
1442   }
1443 
1444   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1445   /// `this` for chaining insertions.
1446   template <typename... Ts>
1447   RewritePatternSet &add() {
1448     (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
1449     return *this;
1450   }
1451 
1452   /// Add the given native pattern to the pattern list. Return a reference to
1453   /// `this` for chaining insertions.
1454   RewritePatternSet &add(std::unique_ptr<RewritePattern> pattern) {
1455     nativePatterns.emplace_back(std::move(pattern));
1456     return *this;
1457   }
1458 
1459   /// Add the given PDL pattern to the pattern list. Return a reference to
1460   /// `this` for chaining insertions.
1461   RewritePatternSet &add(PDLPatternModule &&pattern) {
1462     pdlPatterns.mergeIn(std::move(pattern));
1463     return *this;
1464   }
1465 
1466   // Add a matchAndRewrite style pattern represented as a C function pointer.
1467   template <typename OpType>
1468   RewritePatternSet &add(LogicalResult (*implFn)(OpType,
1469                                                  PatternRewriter &rewriter)) {
1470     struct FnPattern final : public OpRewritePattern<OpType> {
1471       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1472                 MLIRContext *context)
1473           : OpRewritePattern<OpType>(context), implFn(implFn) {}
1474 
1475       LogicalResult matchAndRewrite(OpType op,
1476                                     PatternRewriter &rewriter) const override {
1477         return implFn(op, rewriter);
1478       }
1479 
1480     private:
1481       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1482     };
1483     add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1484     return *this;
1485   }
1486 
1487   //===--------------------------------------------------------------------===//
1488   // Pattern Insertion
1489   //===--------------------------------------------------------------------===//
1490 
1491   // TODO: These are soft deprecated in favor of the 'add' methods above.
1492 
1493   /// Add an instance of each of the pattern types 'Ts' to the pattern list with
1494   /// the given arguments. Return a reference to `this` for chaining insertions.
1495   /// Note: ConstructorArg is necessary here to separate the two variadic lists.
1496   template <typename... Ts, typename ConstructorArg,
1497             typename... ConstructorArgs,
1498             typename = std::enable_if_t<sizeof...(Ts) != 0>>
1499   RewritePatternSet &insert(ConstructorArg &&arg, ConstructorArgs &&...args) {
1500     // The following expands a call to emplace_back for each of the pattern
1501     // types 'Ts'. This magic is necessary due to a limitation in the places
1502     // that a parameter pack can be expanded in c++11.
1503     // FIXME: In c++17 this can be simplified by using 'fold expressions'.
1504     (void)std::initializer_list<int>{
1505         0, (addImpl<Ts>(/*debugLabels=*/llvm::None, arg, args...), 0)...};
1506     return *this;
1507   }
1508 
1509   /// Add an instance of each of the pattern types 'Ts'. Return a reference to
1510   /// `this` for chaining insertions.
1511   template <typename... Ts>
1512   RewritePatternSet &insert() {
1513     (void)std::initializer_list<int>{0, (addImpl<Ts>(), 0)...};
1514     return *this;
1515   }
1516 
1517   /// Add the given native pattern to the pattern list. Return a reference to
1518   /// `this` for chaining insertions.
1519   RewritePatternSet &insert(std::unique_ptr<RewritePattern> pattern) {
1520     nativePatterns.emplace_back(std::move(pattern));
1521     return *this;
1522   }
1523 
1524   /// Add the given PDL pattern to the pattern list. Return a reference to
1525   /// `this` for chaining insertions.
1526   RewritePatternSet &insert(PDLPatternModule &&pattern) {
1527     pdlPatterns.mergeIn(std::move(pattern));
1528     return *this;
1529   }
1530 
1531   // Add a matchAndRewrite style pattern represented as a C function pointer.
1532   template <typename OpType>
1533   RewritePatternSet &
1534   insert(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter)) {
1535     struct FnPattern final : public OpRewritePattern<OpType> {
1536       FnPattern(LogicalResult (*implFn)(OpType, PatternRewriter &rewriter),
1537                 MLIRContext *context)
1538           : OpRewritePattern<OpType>(context), implFn(implFn) {
1539         this->setDebugName(llvm::getTypeName<FnPattern>());
1540       }
1541 
1542       LogicalResult matchAndRewrite(OpType op,
1543                                     PatternRewriter &rewriter) const override {
1544         return implFn(op, rewriter);
1545       }
1546 
1547     private:
1548       LogicalResult (*implFn)(OpType, PatternRewriter &rewriter);
1549     };
1550     add(std::make_unique<FnPattern>(std::move(implFn), getContext()));
1551     return *this;
1552   }
1553 
1554 private:
1555   /// Add an instance of the pattern type 'T'. Return a reference to `this` for
1556   /// chaining insertions.
1557   template <typename T, typename... Args>
1558   std::enable_if_t<std::is_base_of<RewritePattern, T>::value>
1559   addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1560     std::unique_ptr<T> pattern =
1561         RewritePattern::create<T>(std::forward<Args>(args)...);
1562     pattern->addDebugLabels(debugLabels);
1563     nativePatterns.emplace_back(std::move(pattern));
1564   }
1565   template <typename T, typename... Args>
1566   std::enable_if_t<std::is_base_of<PDLPatternModule, T>::value>
1567   addImpl(ArrayRef<StringRef> debugLabels, Args &&...args) {
1568     // TODO: Add the provided labels to the PDL pattern when PDL supports
1569     // labels.
1570     pdlPatterns.mergeIn(T(std::forward<Args>(args)...));
1571   }
1572 
1573   MLIRContext *const context;
1574   NativePatternListT nativePatterns;
1575   PDLPatternModule pdlPatterns;
1576 };
1577 
1578 } // namespace mlir
1579 
1580 #endif // MLIR_IR_PATTERNMATCH_H
1581