1 //===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for "predicates" used when converting PDL into
10 // a matcher tree. Predicates are composed of three different parts:
11 //
12 //  * Positions
13 //    - A position refers to a specific location on the input DAG, i.e. an
14 //      existing MLIR entity being matched. These can be attributes, operands,
15 //      operations, results, and types. Each position also defines a relation to
16 //      its parent. For example, the operand `[0] -> 1` has a parent operation
17 //      position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
18 //      position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
19 //      `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
20 //      without a parent is `[0]`, which refers to the root operation.
21 //  * Questions
22 //    - A question refers to a query on a specific positional value. For
23 //    example, an operation name question checks the name of an operation
24 //    position.
25 //  * Answers
//    - An answer is the expected result of a question. For example, when
//    matching an operation with the name "foo.op", the question would be an
//    operation name question, with an expected answer of "foo.op".
29 //
30 //===----------------------------------------------------------------------===//
31 
32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
34 
35 #include "mlir/IR/MLIRContext.h"
36 #include "mlir/IR/OperationSupport.h"
37 #include "mlir/IR/PatternMatch.h"
38 #include "mlir/IR/Types.h"
39 
40 namespace mlir {
41 namespace pdl_to_pdl_interp {
namespace Predicates {
/// Tag identifying the concrete kind of a predicate node. The enumerators are
/// declared in three consecutive groups (positions, questions, answers) and
/// take consecutive values in declaration order, starting at zero.
enum Kind : unsigned {
  // Positions, ordered by decreasing priority.
  OperationPos = 0,
  OperandPos,
  OperandGroupPos,
  AttributePos,
  ResultPos,
  ResultGroupPos,
  TypePos,
  AttributeLiteralPos,
  TypeLiteralPos,
  UsersPos,
  ForEachPos,

  // Questions, ordered by dependency and decreasing priority.
  IsNotNullQuestion,
  OperationNameQuestion,
  TypeQuestion,
  AttributeQuestion,
  OperandCountAtLeastQuestion,
  OperandCountQuestion,
  ResultCountAtLeastQuestion,
  ResultCountQuestion,
  EqualToQuestion,
  ConstraintQuestion,

  // Answers.
  AttributeAnswer,
  FalseAnswer,
  OperationNameAnswer,
  TrueAnswer,
  TypeAnswer,
  UnsignedAnswer,
};
} // namespace Predicates
79 
80 /// Base class for all predicates, used to allow efficient pointer comparison.
81 template <typename ConcreteT, typename BaseT, typename Key,
82           Predicates::Kind Kind>
83 class PredicateBase : public BaseT {
84 public:
85   using KeyTy = Key;
86   using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>;
87 
88   template <typename KeyT>
PredicateBase(KeyT && key)89   explicit PredicateBase(KeyT &&key)
90       : BaseT(Kind), key(std::forward<KeyT>(key)) {}
91 
92   /// Get an instance of this position.
93   template <typename... Args>
get(StorageUniquer & uniquer,Args &&...args)94   static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) {
95     return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...);
96   }
97 
98   /// Construct an instance with the given storage allocator.
99   template <typename KeyT>
construct(StorageUniquer::StorageAllocator & alloc,KeyT && key)100   static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc,
101                               KeyT &&key) {
102     return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key));
103   }
104 
105   /// Utility methods required by the storage allocator.
106   bool operator==(const KeyTy &key) const { return this->key == key; }
classof(const BaseT * pred)107   static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
108 
109   /// Return the key value of this predicate.
getValue()110   const KeyTy &getValue() const { return key; }
111 
112 protected:
113   KeyTy key;
114 };
115 
116 /// Base storage for simple predicates that only unique with the kind.
117 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind>
118 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT {
119 public:
120   using Base = PredicateBase<ConcreteT, BaseT, void, Kind>;
121 
PredicateBase()122   explicit PredicateBase() : BaseT(Kind) {}
123 
get(StorageUniquer & uniquer)124   static ConcreteT *get(StorageUniquer &uniquer) {
125     return uniquer.get<ConcreteT>();
126   }
classof(const BaseT * pred)127   static bool classof(const BaseT *pred) { return pred->getKind() == Kind; }
128 };
129 
130 //===----------------------------------------------------------------------===//
131 // Positions
132 //===----------------------------------------------------------------------===//
133 
134 struct OperationPosition;
135 
136 /// A position describes a value on the input IR on which a predicate may be
137 /// applied, such as an operation or attribute. This enables re-use between
138 /// predicates, and assists generating bytecode and memory management.
139 ///
140 /// Operation positions form the base of other positions, which are formed
141 /// relative to a parent operation. Operations are anchored at Operand nodes,
142 /// except for the root operation which is parentless.
143 class Position : public StorageUniquer::BaseStorage {
144 public:
Position(Predicates::Kind kind)145   explicit Position(Predicates::Kind kind) : kind(kind) {}
146   virtual ~Position();
147 
148   /// Returns the depth of the first ancestor operation position.
149   unsigned getOperationDepth() const;
150 
151   /// Returns the parent position. The root operation position has no parent.
getParent()152   Position *getParent() const { return parent; }
153 
154   /// Returns the kind of this position.
getKind()155   Predicates::Kind getKind() const { return kind; }
156 
157 protected:
158   /// Link to the parent position.
159   Position *parent = nullptr;
160 
161 private:
162   /// The kind of this position.
163   Predicates::Kind kind;
164 };
165 
166 //===----------------------------------------------------------------------===//
167 // AttributePosition
168 
169 /// A position describing an attribute of an operation.
170 struct AttributePosition
171     : public PredicateBase<AttributePosition, Position,
172                            std::pair<OperationPosition *, StringAttr>,
173                            Predicates::AttributePos> {
174   explicit AttributePosition(const KeyTy &key);
175 
176   /// Returns the attribute name of this position.
getNameAttributePosition177   StringAttr getName() const { return key.second; }
178 };
179 
180 //===----------------------------------------------------------------------===//
181 // AttributeLiteralPosition
182 
183 /// A position describing a literal attribute.
184 struct AttributeLiteralPosition
185     : public PredicateBase<AttributeLiteralPosition, Position, Attribute,
186                            Predicates::AttributeLiteralPos> {
187   using PredicateBase::PredicateBase;
188 };
189 
190 //===----------------------------------------------------------------------===//
191 // ForEachPosition
192 
193 /// A position describing an iterative choice of an operation.
194 struct ForEachPosition : public PredicateBase<ForEachPosition, Position,
195                                               std::pair<Position *, unsigned>,
196                                               Predicates::ForEachPos> {
ForEachPositionForEachPosition197   explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; }
198 
199   /// Returns the ID, for differentiating various loops.
200   /// For upward traversals, this is the index of the root.
getIDForEachPosition201   unsigned getID() const { return key.second; }
202 };
203 
204 //===----------------------------------------------------------------------===//
205 // OperandPosition
206 
207 /// A position describing an operand of an operation.
208 struct OperandPosition
209     : public PredicateBase<OperandPosition, Position,
210                            std::pair<OperationPosition *, unsigned>,
211                            Predicates::OperandPos> {
212   explicit OperandPosition(const KeyTy &key);
213 
214   /// Returns the operand number of this position.
getOperandNumberOperandPosition215   unsigned getOperandNumber() const { return key.second; }
216 };
217 
218 //===----------------------------------------------------------------------===//
219 // OperandGroupPosition
220 
221 /// A position describing an operand group of an operation.
222 struct OperandGroupPosition
223     : public PredicateBase<
224           OperandGroupPosition, Position,
225           std::tuple<OperationPosition *, Optional<unsigned>, bool>,
226           Predicates::OperandGroupPos> {
227   explicit OperandGroupPosition(const KeyTy &key);
228 
229   /// Returns a hash suitable for the given keytype.
hashKeyOperandGroupPosition230   static llvm::hash_code hashKey(const KeyTy &key) {
231     return llvm::hash_value(key);
232   }
233 
234   /// Returns the group number of this position. If None, this group refers to
235   /// all operands.
getOperandGroupNumberOperandGroupPosition236   Optional<unsigned> getOperandGroupNumber() const { return std::get<1>(key); }
237 
238   /// Returns if the operand group has unknown size. If false, the operand group
239   /// has at max one element.
isVariadicOperandGroupPosition240   bool isVariadic() const { return std::get<2>(key); }
241 };
242 
243 //===----------------------------------------------------------------------===//
244 // OperationPosition
245 
246 /// An operation position describes an operation node in the IR. Other position
247 /// kinds are formed with respect to an operation position.
248 struct OperationPosition : public PredicateBase<OperationPosition, Position,
249                                                 std::pair<Position *, unsigned>,
250                                                 Predicates::OperationPos> {
OperationPositionOperationPosition251   explicit OperationPosition(const KeyTy &key) : Base(key) {
252     parent = key.first;
253   }
254 
255   /// Returns a hash suitable for the given keytype.
hashKeyOperationPosition256   static llvm::hash_code hashKey(const KeyTy &key) {
257     return llvm::hash_value(key);
258   }
259 
260   /// Gets the root position.
getRootOperationPosition261   static OperationPosition *getRoot(StorageUniquer &uniquer) {
262     return Base::get(uniquer, nullptr, 0);
263   }
264 
265   /// Gets an operation position with the given parent.
getOperationPosition266   static OperationPosition *get(StorageUniquer &uniquer, Position *parent) {
267     return Base::get(uniquer, parent, parent->getOperationDepth() + 1);
268   }
269 
270   /// Returns the depth of this position.
getDepthOperationPosition271   unsigned getDepth() const { return key.second; }
272 
273   /// Returns if this operation position corresponds to the root.
isRootOperationPosition274   bool isRoot() const { return getDepth() == 0; }
275 
276   /// Returns if this operation represents an operand defining op.
277   bool isOperandDefiningOp() const;
278 };
279 
280 //===----------------------------------------------------------------------===//
281 // ResultPosition
282 
283 /// A position describing a result of an operation.
284 struct ResultPosition
285     : public PredicateBase<ResultPosition, Position,
286                            std::pair<OperationPosition *, unsigned>,
287                            Predicates::ResultPos> {
ResultPositionResultPosition288   explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; }
289 
290   /// Returns the result number of this position.
getResultNumberResultPosition291   unsigned getResultNumber() const { return key.second; }
292 };
293 
294 //===----------------------------------------------------------------------===//
295 // ResultGroupPosition
296 
297 /// A position describing a result group of an operation.
298 struct ResultGroupPosition
299     : public PredicateBase<
300           ResultGroupPosition, Position,
301           std::tuple<OperationPosition *, Optional<unsigned>, bool>,
302           Predicates::ResultGroupPos> {
ResultGroupPositionResultGroupPosition303   explicit ResultGroupPosition(const KeyTy &key) : Base(key) {
304     parent = std::get<0>(key);
305   }
306 
307   /// Returns a hash suitable for the given keytype.
hashKeyResultGroupPosition308   static llvm::hash_code hashKey(const KeyTy &key) {
309     return llvm::hash_value(key);
310   }
311 
312   /// Returns the group number of this position. If None, this group refers to
313   /// all results.
getResultGroupNumberResultGroupPosition314   Optional<unsigned> getResultGroupNumber() const { return std::get<1>(key); }
315 
316   /// Returns if the result group has unknown size. If false, the result group
317   /// has at max one element.
isVariadicResultGroupPosition318   bool isVariadic() const { return std::get<2>(key); }
319 };
320 
321 //===----------------------------------------------------------------------===//
322 // TypePosition
323 
324 /// A position describing the result type of an entity, i.e. an Attribute,
325 /// Operand, Result, etc.
326 struct TypePosition : public PredicateBase<TypePosition, Position, Position *,
327                                            Predicates::TypePos> {
TypePositionTypePosition328   explicit TypePosition(const KeyTy &key) : Base(key) {
329     assert((isa<AttributePosition, OperandPosition, OperandGroupPosition,
330                 ResultPosition, ResultGroupPosition>(key)) &&
331            "expected parent to be an attribute, operand, or result");
332     parent = key;
333   }
334 };
335 
336 //===----------------------------------------------------------------------===//
337 // TypeLiteralPosition
338 
339 /// A position describing a literal type or type range. The value is stored as
340 /// either a TypeAttr, or an ArrayAttr of TypeAttr.
341 struct TypeLiteralPosition
342     : public PredicateBase<TypeLiteralPosition, Position, Attribute,
343                            Predicates::TypeLiteralPos> {
344   using PredicateBase::PredicateBase;
345 };
346 
347 //===----------------------------------------------------------------------===//
348 // UsersPosition
349 
350 /// A position describing the users of a value or a range of values. The second
351 /// value in the key indicates whether we choose users of a representative for
352 /// a range (this is true, e.g., in the upward traversals).
353 struct UsersPosition
354     : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>,
355                            Predicates::UsersPos> {
UsersPositionUsersPosition356   explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; }
357 
358   /// Returns a hash suitable for the given keytype.
hashKeyUsersPosition359   static llvm::hash_code hashKey(const KeyTy &key) {
360     return llvm::hash_value(key);
361   }
362 
363   /// Indicates whether to compute a range of a representative.
useRepresentativeUsersPosition364   bool useRepresentative() const { return key.second; }
365 };
366 
367 //===----------------------------------------------------------------------===//
368 // Qualifiers
369 //===----------------------------------------------------------------------===//
370 
371 /// An ordinal predicate consists of a "Question" and a set of acceptable
372 /// "Answers" (later converted to ordinal values). A predicate will query some
373 /// property of a positional value and decide what to do based on the result.
374 ///
375 /// This makes top-level predicate representations ordinal (SwitchOp). Later,
376 /// predicates that end up with only one acceptable answer (including all
377 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the
378 /// matcher.
379 ///
380 /// For simplicity, both are represented as "qualifiers", with a base kind and
381 /// perhaps additional properties. For example, all OperationName predicates ask
382 /// the same question, but GenericConstraint predicates may ask different ones.
383 class Qualifier : public StorageUniquer::BaseStorage {
384 public:
Qualifier(Predicates::Kind kind)385   explicit Qualifier(Predicates::Kind kind) : kind(kind) {}
386 
387   /// Returns the kind of this qualifier.
getKind()388   Predicates::Kind getKind() const { return kind; }
389 
390 private:
391   /// The kind of this position.
392   Predicates::Kind kind;
393 };
394 
395 //===----------------------------------------------------------------------===//
396 // Answers
397 
398 /// An Answer representing an `Attribute` value.
399 struct AttributeAnswer
400     : public PredicateBase<AttributeAnswer, Qualifier, Attribute,
401                            Predicates::AttributeAnswer> {
402   using Base::Base;
403 };
404 
405 /// An Answer representing an `OperationName` value.
406 struct OperationNameAnswer
407     : public PredicateBase<OperationNameAnswer, Qualifier, OperationName,
408                            Predicates::OperationNameAnswer> {
409   using Base::Base;
410 };
411 
412 /// An Answer representing a boolean `true` value.
413 struct TrueAnswer
414     : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> {
415   using Base::Base;
416 };
417 
418 /// An Answer representing a boolean 'false' value.
419 struct FalseAnswer
420     : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> {
421   using Base::Base;
422 };
423 
424 /// An Answer representing a `Type` value. The value is stored as either a
425 /// TypeAttr, or an ArrayAttr of TypeAttr.
426 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute,
427                                          Predicates::TypeAnswer> {
428   using Base::Base;
429 };
430 
431 /// An Answer representing an unsigned value.
432 struct UnsignedAnswer
433     : public PredicateBase<UnsignedAnswer, Qualifier, unsigned,
434                            Predicates::UnsignedAnswer> {
435   using Base::Base;
436 };
437 
438 //===----------------------------------------------------------------------===//
439 // Questions
440 
441 /// Compare an `Attribute` to a constant value.
442 struct AttributeQuestion
443     : public PredicateBase<AttributeQuestion, Qualifier, void,
444                            Predicates::AttributeQuestion> {};
445 
446 /// Apply a parameterized constraint to multiple position values.
447 struct ConstraintQuestion
448     : public PredicateBase<ConstraintQuestion, Qualifier,
449                            std::tuple<StringRef, ArrayRef<Position *>>,
450                            Predicates::ConstraintQuestion> {
451   using Base::Base;
452 
453   /// Return the name of the constraint.
getNameConstraintQuestion454   StringRef getName() const { return std::get<0>(key); }
455 
456   /// Return the arguments of the constraint.
getArgsConstraintQuestion457   ArrayRef<Position *> getArgs() const { return std::get<1>(key); }
458 
459   /// Construct an instance with the given storage allocator.
constructConstraintQuestion460   static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc,
461                                        KeyTy key) {
462     return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)),
463                                         alloc.copyInto(std::get<1>(key))});
464   }
465 };
466 
467 /// Compare the equality of two values.
468 struct EqualToQuestion
469     : public PredicateBase<EqualToQuestion, Qualifier, Position *,
470                            Predicates::EqualToQuestion> {
471   using Base::Base;
472 };
473 
474 /// Compare a positional value with null, i.e. check if it exists.
475 struct IsNotNullQuestion
476     : public PredicateBase<IsNotNullQuestion, Qualifier, void,
477                            Predicates::IsNotNullQuestion> {};
478 
479 /// Compare the number of operands of an operation with a known value.
480 struct OperandCountQuestion
481     : public PredicateBase<OperandCountQuestion, Qualifier, void,
482                            Predicates::OperandCountQuestion> {};
483 struct OperandCountAtLeastQuestion
484     : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void,
485                            Predicates::OperandCountAtLeastQuestion> {};
486 
487 /// Compare the name of an operation with a known value.
488 struct OperationNameQuestion
489     : public PredicateBase<OperationNameQuestion, Qualifier, void,
490                            Predicates::OperationNameQuestion> {};
491 
492 /// Compare the number of results of an operation with a known value.
493 struct ResultCountQuestion
494     : public PredicateBase<ResultCountQuestion, Qualifier, void,
495                            Predicates::ResultCountQuestion> {};
496 struct ResultCountAtLeastQuestion
497     : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void,
498                            Predicates::ResultCountAtLeastQuestion> {};
499 
500 /// Compare the type of an attribute or value with a known type.
501 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void,
502                                            Predicates::TypeQuestion> {};
503 
504 //===----------------------------------------------------------------------===//
505 // PredicateUniquer
506 //===----------------------------------------------------------------------===//
507 
508 /// This class provides a storage uniquer that is used to allocate predicate
509 /// instances.
510 class PredicateUniquer : public StorageUniquer {
511 public:
PredicateUniquer()512   PredicateUniquer() {
513     // Register the types of Positions with the uniquer.
514     registerParametricStorageType<AttributePosition>();
515     registerParametricStorageType<AttributeLiteralPosition>();
516     registerParametricStorageType<ForEachPosition>();
517     registerParametricStorageType<OperandPosition>();
518     registerParametricStorageType<OperandGroupPosition>();
519     registerParametricStorageType<OperationPosition>();
520     registerParametricStorageType<ResultPosition>();
521     registerParametricStorageType<ResultGroupPosition>();
522     registerParametricStorageType<TypePosition>();
523     registerParametricStorageType<TypeLiteralPosition>();
524     registerParametricStorageType<UsersPosition>();
525 
526     // Register the types of Questions with the uniquer.
527     registerParametricStorageType<AttributeAnswer>();
528     registerParametricStorageType<OperationNameAnswer>();
529     registerParametricStorageType<TypeAnswer>();
530     registerParametricStorageType<UnsignedAnswer>();
531     registerSingletonStorageType<FalseAnswer>();
532     registerSingletonStorageType<TrueAnswer>();
533 
534     // Register the types of Answers with the uniquer.
535     registerParametricStorageType<ConstraintQuestion>();
536     registerParametricStorageType<EqualToQuestion>();
537     registerSingletonStorageType<AttributeQuestion>();
538     registerSingletonStorageType<IsNotNullQuestion>();
539     registerSingletonStorageType<OperandCountQuestion>();
540     registerSingletonStorageType<OperandCountAtLeastQuestion>();
541     registerSingletonStorageType<OperationNameQuestion>();
542     registerSingletonStorageType<ResultCountQuestion>();
543     registerSingletonStorageType<ResultCountAtLeastQuestion>();
544     registerSingletonStorageType<TypeQuestion>();
545   }
546 };
547 
548 //===----------------------------------------------------------------------===//
549 // PredicateBuilder
550 //===----------------------------------------------------------------------===//
551 
552 /// This class provides utilities for constructing predicates.
553 class PredicateBuilder {
554 public:
PredicateBuilder(PredicateUniquer & uniquer,MLIRContext * ctx)555   PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx)
556       : uniquer(uniquer), ctx(ctx) {}
557 
558   //===--------------------------------------------------------------------===//
559   // Positions
560   //===--------------------------------------------------------------------===//
561 
562   /// Returns the root operation position.
getRoot()563   Position *getRoot() { return OperationPosition::getRoot(uniquer); }
564 
565   /// Returns the parent position defining the value held by the given operand.
getOperandDefiningOp(Position * p)566   OperationPosition *getOperandDefiningOp(Position *p) {
567     assert((isa<OperandPosition, OperandGroupPosition>(p)) &&
568            "expected operand position");
569     return OperationPosition::get(uniquer, p);
570   }
571 
572   /// Returns the operation position equivalent to the given position.
getPassthroughOp(Position * p)573   OperationPosition *getPassthroughOp(Position *p) {
574     assert((isa<ForEachPosition>(p)) && "expected users position");
575     return OperationPosition::get(uniquer, p);
576   }
577 
578   /// Returns an attribute position for an attribute of the given operation.
getAttribute(OperationPosition * p,StringRef name)579   Position *getAttribute(OperationPosition *p, StringRef name) {
580     return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name));
581   }
582 
583   /// Returns an attribute position for the given attribute.
getAttributeLiteral(Attribute attr)584   Position *getAttributeLiteral(Attribute attr) {
585     return AttributeLiteralPosition::get(uniquer, attr);
586   }
587 
getForEach(Position * p,unsigned id)588   Position *getForEach(Position *p, unsigned id) {
589     return ForEachPosition::get(uniquer, p, id);
590   }
591 
592   /// Returns an operand position for an operand of the given operation.
getOperand(OperationPosition * p,unsigned operand)593   Position *getOperand(OperationPosition *p, unsigned operand) {
594     return OperandPosition::get(uniquer, p, operand);
595   }
596 
597   /// Returns a position for a group of operands of the given operation.
getOperandGroup(OperationPosition * p,Optional<unsigned> group,bool isVariadic)598   Position *getOperandGroup(OperationPosition *p, Optional<unsigned> group,
599                             bool isVariadic) {
600     return OperandGroupPosition::get(uniquer, p, group, isVariadic);
601   }
getAllOperands(OperationPosition * p)602   Position *getAllOperands(OperationPosition *p) {
603     return getOperandGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
604   }
605 
606   /// Returns a result position for a result of the given operation.
getResult(OperationPosition * p,unsigned result)607   Position *getResult(OperationPosition *p, unsigned result) {
608     return ResultPosition::get(uniquer, p, result);
609   }
610 
611   /// Returns a position for a group of results of the given operation.
getResultGroup(OperationPosition * p,Optional<unsigned> group,bool isVariadic)612   Position *getResultGroup(OperationPosition *p, Optional<unsigned> group,
613                            bool isVariadic) {
614     return ResultGroupPosition::get(uniquer, p, group, isVariadic);
615   }
getAllResults(OperationPosition * p)616   Position *getAllResults(OperationPosition *p) {
617     return getResultGroup(p, /*group=*/llvm::None, /*isVariadic=*/true);
618   }
619 
620   /// Returns a type position for the given entity.
getType(Position * p)621   Position *getType(Position *p) { return TypePosition::get(uniquer, p); }
622 
623   /// Returns a type position for the given type value. The value is stored
624   /// as either a TypeAttr, or an ArrayAttr of TypeAttr.
getTypeLiteral(Attribute attr)625   Position *getTypeLiteral(Attribute attr) {
626     return TypeLiteralPosition::get(uniquer, attr);
627   }
628 
629   /// Returns the users of a position using the value at the given operand.
getUsers(Position * p,bool useRepresentative)630   UsersPosition *getUsers(Position *p, bool useRepresentative) {
631     assert((isa<OperandPosition, OperandGroupPosition, ResultPosition,
632                 ResultGroupPosition>(p)) &&
633            "expected result position");
634     return UsersPosition::get(uniquer, p, useRepresentative);
635   }
636 
637   //===--------------------------------------------------------------------===//
638   // Qualifiers
639   //===--------------------------------------------------------------------===//
640 
641   /// An ordinal predicate consists of a "Question" and a set of acceptable
642   /// "Answers" (later converted to ordinal values). A predicate will query some
643   /// property of a positional value and decide what to do based on the result.
644   using Predicate = std::pair<Qualifier *, Qualifier *>;
645 
646   /// Create a predicate comparing an attribute to a known value.
getAttributeConstraint(Attribute attr)647   Predicate getAttributeConstraint(Attribute attr) {
648     return {AttributeQuestion::get(uniquer),
649             AttributeAnswer::get(uniquer, attr)};
650   }
651 
652   /// Create a predicate checking if two values are equal.
getEqualTo(Position * pos)653   Predicate getEqualTo(Position *pos) {
654     return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)};
655   }
656 
657   /// Create a predicate checking if two values are not equal.
getNotEqualTo(Position * pos)658   Predicate getNotEqualTo(Position *pos) {
659     return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)};
660   }
661 
662   /// Create a predicate that applies a generic constraint.
getConstraint(StringRef name,ArrayRef<Position * > pos)663   Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) {
664     return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)),
665             TrueAnswer::get(uniquer)};
666   }
667 
668   /// Create a predicate comparing a value with null.
getIsNotNull()669   Predicate getIsNotNull() {
670     return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)};
671   }
672 
673   /// Create a predicate comparing the number of operands of an operation to a
674   /// known value.
getOperandCount(unsigned count)675   Predicate getOperandCount(unsigned count) {
676     return {OperandCountQuestion::get(uniquer),
677             UnsignedAnswer::get(uniquer, count)};
678   }
getOperandCountAtLeast(unsigned count)679   Predicate getOperandCountAtLeast(unsigned count) {
680     return {OperandCountAtLeastQuestion::get(uniquer),
681             UnsignedAnswer::get(uniquer, count)};
682   }
683 
684   /// Create a predicate comparing the name of an operation to a known value.
getOperationName(StringRef name)685   Predicate getOperationName(StringRef name) {
686     return {OperationNameQuestion::get(uniquer),
687             OperationNameAnswer::get(uniquer, OperationName(name, ctx))};
688   }
689 
690   /// Create a predicate comparing the number of results of an operation to a
691   /// known value.
getResultCount(unsigned count)692   Predicate getResultCount(unsigned count) {
693     return {ResultCountQuestion::get(uniquer),
694             UnsignedAnswer::get(uniquer, count)};
695   }
getResultCountAtLeast(unsigned count)696   Predicate getResultCountAtLeast(unsigned count) {
697     return {ResultCountAtLeastQuestion::get(uniquer),
698             UnsignedAnswer::get(uniquer, count)};
699   }
700 
701   /// Create a predicate comparing the type of an attribute or value to a known
702   /// type. The value is stored as either a TypeAttr, or an ArrayAttr of
703   /// TypeAttr.
getTypeConstraint(Attribute type)704   Predicate getTypeConstraint(Attribute type) {
705     return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)};
706   }
707 
708 private:
709   /// The uniquer used when allocating predicate nodes.
710   PredicateUniquer &uniquer;
711 
712   /// The current MLIR context.
713   MLIRContext *ctx;
714 };
715 
716 } // namespace pdl_to_pdl_interp
717 } // namespace mlir
718 
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_
720