//===- Predicate.h - Pattern predicates -------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains definitions for "predicates" used when converting PDL into
// a matcher tree. Predicates are composed of three different parts:
//
//  * Positions
//    - A position refers to a specific location on the input DAG, i.e. an
//      existing MLIR entity being matched. These can be attributes, operands,
//      operations, results, and types. Each position also defines a relation to
//      its parent. For example, the operand `[0] -> 1` has a parent operation
//      position `[0]`. The attribute `[0, 1] -> "myAttr"` has parent operation
//      position of `[0, 1]`. The operation `[0, 1]` has a parent operand edge
//      `[0] -> 1` (i.e. it is the defining op of operand 1). The only position
//      without a parent is `[0]`, which refers to the root operation.
//  * Questions
//    - A question refers to a query on a specific positional value. For
//      example, an operation name question checks the name of an operation
//      position.
//  * Answers
//    - An answer is the expected result of a question. For example, when
//      matching an operation with the name "foo.op", the question would be an
//      operation name question, with an expected answer of "foo.op".
29 // 30 //===----------------------------------------------------------------------===// 31 32 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 33 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 34 35 #include "mlir/IR/MLIRContext.h" 36 #include "mlir/IR/OperationSupport.h" 37 #include "mlir/IR/PatternMatch.h" 38 #include "mlir/IR/Types.h" 39 40 namespace mlir { 41 namespace pdl_to_pdl_interp { 42 namespace Predicates { 43 /// An enumeration of the kinds of predicates. 44 enum Kind : unsigned { 45 /// Positions, ordered by decreasing priority. 46 OperationPos, 47 OperandPos, 48 OperandGroupPos, 49 AttributePos, 50 ResultPos, 51 ResultGroupPos, 52 TypePos, 53 AttributeLiteralPos, 54 TypeLiteralPos, 55 UsersPos, 56 ForEachPos, 57 58 // Questions, ordered by dependency and decreasing priority. 59 IsNotNullQuestion, 60 OperationNameQuestion, 61 TypeQuestion, 62 AttributeQuestion, 63 OperandCountAtLeastQuestion, 64 OperandCountQuestion, 65 ResultCountAtLeastQuestion, 66 ResultCountQuestion, 67 EqualToQuestion, 68 ConstraintQuestion, 69 70 // Answers. 71 AttributeAnswer, 72 FalseAnswer, 73 OperationNameAnswer, 74 TrueAnswer, 75 TypeAnswer, 76 UnsignedAnswer, 77 }; 78 } // namespace Predicates 79 80 /// Base class for all predicates, used to allow efficient pointer comparison. 81 template <typename ConcreteT, typename BaseT, typename Key, 82 Predicates::Kind Kind> 83 class PredicateBase : public BaseT { 84 public: 85 using KeyTy = Key; 86 using Base = PredicateBase<ConcreteT, BaseT, Key, Kind>; 87 88 template <typename KeyT> PredicateBase(KeyT && key)89 explicit PredicateBase(KeyT &&key) 90 : BaseT(Kind), key(std::forward<KeyT>(key)) {} 91 92 /// Get an instance of this position. 93 template <typename... 
Args> get(StorageUniquer & uniquer,Args &&...args)94 static ConcreteT *get(StorageUniquer &uniquer, Args &&...args) { 95 return uniquer.get<ConcreteT>(/*initFn=*/{}, std::forward<Args>(args)...); 96 } 97 98 /// Construct an instance with the given storage allocator. 99 template <typename KeyT> construct(StorageUniquer::StorageAllocator & alloc,KeyT && key)100 static ConcreteT *construct(StorageUniquer::StorageAllocator &alloc, 101 KeyT &&key) { 102 return new (alloc.allocate<ConcreteT>()) ConcreteT(std::forward<KeyT>(key)); 103 } 104 105 /// Utility methods required by the storage allocator. 106 bool operator==(const KeyTy &key) const { return this->key == key; } classof(const BaseT * pred)107 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } 108 109 /// Return the key value of this predicate. getValue()110 const KeyTy &getValue() const { return key; } 111 112 protected: 113 KeyTy key; 114 }; 115 116 /// Base storage for simple predicates that only unique with the kind. 117 template <typename ConcreteT, typename BaseT, Predicates::Kind Kind> 118 class PredicateBase<ConcreteT, BaseT, void, Kind> : public BaseT { 119 public: 120 using Base = PredicateBase<ConcreteT, BaseT, void, Kind>; 121 PredicateBase()122 explicit PredicateBase() : BaseT(Kind) {} 123 get(StorageUniquer & uniquer)124 static ConcreteT *get(StorageUniquer &uniquer) { 125 return uniquer.get<ConcreteT>(); 126 } classof(const BaseT * pred)127 static bool classof(const BaseT *pred) { return pred->getKind() == Kind; } 128 }; 129 130 //===----------------------------------------------------------------------===// 131 // Positions 132 //===----------------------------------------------------------------------===// 133 134 struct OperationPosition; 135 136 /// A position describes a value on the input IR on which a predicate may be 137 /// applied, such as an operation or attribute. This enables re-use between 138 /// predicates, and assists generating bytecode and memory management. 
139 /// 140 /// Operation positions form the base of other positions, which are formed 141 /// relative to a parent operation. Operations are anchored at Operand nodes, 142 /// except for the root operation which is parentless. 143 class Position : public StorageUniquer::BaseStorage { 144 public: Position(Predicates::Kind kind)145 explicit Position(Predicates::Kind kind) : kind(kind) {} 146 virtual ~Position(); 147 148 /// Returns the depth of the first ancestor operation position. 149 unsigned getOperationDepth() const; 150 151 /// Returns the parent position. The root operation position has no parent. getParent()152 Position *getParent() const { return parent; } 153 154 /// Returns the kind of this position. getKind()155 Predicates::Kind getKind() const { return kind; } 156 157 protected: 158 /// Link to the parent position. 159 Position *parent = nullptr; 160 161 private: 162 /// The kind of this position. 163 Predicates::Kind kind; 164 }; 165 166 //===----------------------------------------------------------------------===// 167 // AttributePosition 168 169 /// A position describing an attribute of an operation. 170 struct AttributePosition 171 : public PredicateBase<AttributePosition, Position, 172 std::pair<OperationPosition *, StringAttr>, 173 Predicates::AttributePos> { 174 explicit AttributePosition(const KeyTy &key); 175 176 /// Returns the attribute name of this position. getNameAttributePosition177 StringAttr getName() const { return key.second; } 178 }; 179 180 //===----------------------------------------------------------------------===// 181 // AttributeLiteralPosition 182 183 /// A position describing a literal attribute. 
184 struct AttributeLiteralPosition 185 : public PredicateBase<AttributeLiteralPosition, Position, Attribute, 186 Predicates::AttributeLiteralPos> { 187 using PredicateBase::PredicateBase; 188 }; 189 190 //===----------------------------------------------------------------------===// 191 // ForEachPosition 192 193 /// A position describing an iterative choice of an operation. 194 struct ForEachPosition : public PredicateBase<ForEachPosition, Position, 195 std::pair<Position *, unsigned>, 196 Predicates::ForEachPos> { ForEachPositionForEachPosition197 explicit ForEachPosition(const KeyTy &key) : Base(key) { parent = key.first; } 198 199 /// Returns the ID, for differentiating various loops. 200 /// For upward traversals, this is the index of the root. getIDForEachPosition201 unsigned getID() const { return key.second; } 202 }; 203 204 //===----------------------------------------------------------------------===// 205 // OperandPosition 206 207 /// A position describing an operand of an operation. 208 struct OperandPosition 209 : public PredicateBase<OperandPosition, Position, 210 std::pair<OperationPosition *, unsigned>, 211 Predicates::OperandPos> { 212 explicit OperandPosition(const KeyTy &key); 213 214 /// Returns the operand number of this position. getOperandNumberOperandPosition215 unsigned getOperandNumber() const { return key.second; } 216 }; 217 218 //===----------------------------------------------------------------------===// 219 // OperandGroupPosition 220 221 /// A position describing an operand group of an operation. 222 struct OperandGroupPosition 223 : public PredicateBase< 224 OperandGroupPosition, Position, 225 std::tuple<OperationPosition *, Optional<unsigned>, bool>, 226 Predicates::OperandGroupPos> { 227 explicit OperandGroupPosition(const KeyTy &key); 228 229 /// Returns a hash suitable for the given keytype. 
hashKeyOperandGroupPosition230 static llvm::hash_code hashKey(const KeyTy &key) { 231 return llvm::hash_value(key); 232 } 233 234 /// Returns the group number of this position. If None, this group refers to 235 /// all operands. getOperandGroupNumberOperandGroupPosition236 Optional<unsigned> getOperandGroupNumber() const { return std::get<1>(key); } 237 238 /// Returns if the operand group has unknown size. If false, the operand group 239 /// has at max one element. isVariadicOperandGroupPosition240 bool isVariadic() const { return std::get<2>(key); } 241 }; 242 243 //===----------------------------------------------------------------------===// 244 // OperationPosition 245 246 /// An operation position describes an operation node in the IR. Other position 247 /// kinds are formed with respect to an operation position. 248 struct OperationPosition : public PredicateBase<OperationPosition, Position, 249 std::pair<Position *, unsigned>, 250 Predicates::OperationPos> { OperationPositionOperationPosition251 explicit OperationPosition(const KeyTy &key) : Base(key) { 252 parent = key.first; 253 } 254 255 /// Returns a hash suitable for the given keytype. hashKeyOperationPosition256 static llvm::hash_code hashKey(const KeyTy &key) { 257 return llvm::hash_value(key); 258 } 259 260 /// Gets the root position. getRootOperationPosition261 static OperationPosition *getRoot(StorageUniquer &uniquer) { 262 return Base::get(uniquer, nullptr, 0); 263 } 264 265 /// Gets an operation position with the given parent. getOperationPosition266 static OperationPosition *get(StorageUniquer &uniquer, Position *parent) { 267 return Base::get(uniquer, parent, parent->getOperationDepth() + 1); 268 } 269 270 /// Returns the depth of this position. getDepthOperationPosition271 unsigned getDepth() const { return key.second; } 272 273 /// Returns if this operation position corresponds to the root. 
isRootOperationPosition274 bool isRoot() const { return getDepth() == 0; } 275 276 /// Returns if this operation represents an operand defining op. 277 bool isOperandDefiningOp() const; 278 }; 279 280 //===----------------------------------------------------------------------===// 281 // ResultPosition 282 283 /// A position describing a result of an operation. 284 struct ResultPosition 285 : public PredicateBase<ResultPosition, Position, 286 std::pair<OperationPosition *, unsigned>, 287 Predicates::ResultPos> { ResultPositionResultPosition288 explicit ResultPosition(const KeyTy &key) : Base(key) { parent = key.first; } 289 290 /// Returns the result number of this position. getResultNumberResultPosition291 unsigned getResultNumber() const { return key.second; } 292 }; 293 294 //===----------------------------------------------------------------------===// 295 // ResultGroupPosition 296 297 /// A position describing a result group of an operation. 298 struct ResultGroupPosition 299 : public PredicateBase< 300 ResultGroupPosition, Position, 301 std::tuple<OperationPosition *, Optional<unsigned>, bool>, 302 Predicates::ResultGroupPos> { ResultGroupPositionResultGroupPosition303 explicit ResultGroupPosition(const KeyTy &key) : Base(key) { 304 parent = std::get<0>(key); 305 } 306 307 /// Returns a hash suitable for the given keytype. hashKeyResultGroupPosition308 static llvm::hash_code hashKey(const KeyTy &key) { 309 return llvm::hash_value(key); 310 } 311 312 /// Returns the group number of this position. If None, this group refers to 313 /// all results. getResultGroupNumberResultGroupPosition314 Optional<unsigned> getResultGroupNumber() const { return std::get<1>(key); } 315 316 /// Returns if the result group has unknown size. If false, the result group 317 /// has at max one element. 
isVariadicResultGroupPosition318 bool isVariadic() const { return std::get<2>(key); } 319 }; 320 321 //===----------------------------------------------------------------------===// 322 // TypePosition 323 324 /// A position describing the result type of an entity, i.e. an Attribute, 325 /// Operand, Result, etc. 326 struct TypePosition : public PredicateBase<TypePosition, Position, Position *, 327 Predicates::TypePos> { TypePositionTypePosition328 explicit TypePosition(const KeyTy &key) : Base(key) { 329 assert((isa<AttributePosition, OperandPosition, OperandGroupPosition, 330 ResultPosition, ResultGroupPosition>(key)) && 331 "expected parent to be an attribute, operand, or result"); 332 parent = key; 333 } 334 }; 335 336 //===----------------------------------------------------------------------===// 337 // TypeLiteralPosition 338 339 /// A position describing a literal type or type range. The value is stored as 340 /// either a TypeAttr, or an ArrayAttr of TypeAttr. 341 struct TypeLiteralPosition 342 : public PredicateBase<TypeLiteralPosition, Position, Attribute, 343 Predicates::TypeLiteralPos> { 344 using PredicateBase::PredicateBase; 345 }; 346 347 //===----------------------------------------------------------------------===// 348 // UsersPosition 349 350 /// A position describing the users of a value or a range of values. The second 351 /// value in the key indicates whether we choose users of a representative for 352 /// a range (this is true, e.g., in the upward traversals). 353 struct UsersPosition 354 : public PredicateBase<UsersPosition, Position, std::pair<Position *, bool>, 355 Predicates::UsersPos> { UsersPositionUsersPosition356 explicit UsersPosition(const KeyTy &key) : Base(key) { parent = key.first; } 357 358 /// Returns a hash suitable for the given keytype. hashKeyUsersPosition359 static llvm::hash_code hashKey(const KeyTy &key) { 360 return llvm::hash_value(key); 361 } 362 363 /// Indicates whether to compute a range of a representative. 
useRepresentativeUsersPosition364 bool useRepresentative() const { return key.second; } 365 }; 366 367 //===----------------------------------------------------------------------===// 368 // Qualifiers 369 //===----------------------------------------------------------------------===// 370 371 /// An ordinal predicate consists of a "Question" and a set of acceptable 372 /// "Answers" (later converted to ordinal values). A predicate will query some 373 /// property of a positional value and decide what to do based on the result. 374 /// 375 /// This makes top-level predicate representations ordinal (SwitchOp). Later, 376 /// predicates that end up with only one acceptable answer (including all 377 /// boolean kinds) will be converted to boolean predicates (PredicateOp) in the 378 /// matcher. 379 /// 380 /// For simplicity, both are represented as "qualifiers", with a base kind and 381 /// perhaps additional properties. For example, all OperationName predicates ask 382 /// the same question, but GenericConstraint predicates may ask different ones. 383 class Qualifier : public StorageUniquer::BaseStorage { 384 public: Qualifier(Predicates::Kind kind)385 explicit Qualifier(Predicates::Kind kind) : kind(kind) {} 386 387 /// Returns the kind of this qualifier. getKind()388 Predicates::Kind getKind() const { return kind; } 389 390 private: 391 /// The kind of this position. 392 Predicates::Kind kind; 393 }; 394 395 //===----------------------------------------------------------------------===// 396 // Answers 397 398 /// An Answer representing an `Attribute` value. 399 struct AttributeAnswer 400 : public PredicateBase<AttributeAnswer, Qualifier, Attribute, 401 Predicates::AttributeAnswer> { 402 using Base::Base; 403 }; 404 405 /// An Answer representing an `OperationName` value. 
406 struct OperationNameAnswer 407 : public PredicateBase<OperationNameAnswer, Qualifier, OperationName, 408 Predicates::OperationNameAnswer> { 409 using Base::Base; 410 }; 411 412 /// An Answer representing a boolean `true` value. 413 struct TrueAnswer 414 : PredicateBase<TrueAnswer, Qualifier, void, Predicates::TrueAnswer> { 415 using Base::Base; 416 }; 417 418 /// An Answer representing a boolean 'false' value. 419 struct FalseAnswer 420 : PredicateBase<FalseAnswer, Qualifier, void, Predicates::FalseAnswer> { 421 using Base::Base; 422 }; 423 424 /// An Answer representing a `Type` value. The value is stored as either a 425 /// TypeAttr, or an ArrayAttr of TypeAttr. 426 struct TypeAnswer : public PredicateBase<TypeAnswer, Qualifier, Attribute, 427 Predicates::TypeAnswer> { 428 using Base::Base; 429 }; 430 431 /// An Answer representing an unsigned value. 432 struct UnsignedAnswer 433 : public PredicateBase<UnsignedAnswer, Qualifier, unsigned, 434 Predicates::UnsignedAnswer> { 435 using Base::Base; 436 }; 437 438 //===----------------------------------------------------------------------===// 439 // Questions 440 441 /// Compare an `Attribute` to a constant value. 442 struct AttributeQuestion 443 : public PredicateBase<AttributeQuestion, Qualifier, void, 444 Predicates::AttributeQuestion> {}; 445 446 /// Apply a parameterized constraint to multiple position values. 447 struct ConstraintQuestion 448 : public PredicateBase<ConstraintQuestion, Qualifier, 449 std::tuple<StringRef, ArrayRef<Position *>>, 450 Predicates::ConstraintQuestion> { 451 using Base::Base; 452 453 /// Return the name of the constraint. getNameConstraintQuestion454 StringRef getName() const { return std::get<0>(key); } 455 456 /// Return the arguments of the constraint. getArgsConstraintQuestion457 ArrayRef<Position *> getArgs() const { return std::get<1>(key); } 458 459 /// Construct an instance with the given storage allocator. 
constructConstraintQuestion460 static ConstraintQuestion *construct(StorageUniquer::StorageAllocator &alloc, 461 KeyTy key) { 462 return Base::construct(alloc, KeyTy{alloc.copyInto(std::get<0>(key)), 463 alloc.copyInto(std::get<1>(key))}); 464 } 465 }; 466 467 /// Compare the equality of two values. 468 struct EqualToQuestion 469 : public PredicateBase<EqualToQuestion, Qualifier, Position *, 470 Predicates::EqualToQuestion> { 471 using Base::Base; 472 }; 473 474 /// Compare a positional value with null, i.e. check if it exists. 475 struct IsNotNullQuestion 476 : public PredicateBase<IsNotNullQuestion, Qualifier, void, 477 Predicates::IsNotNullQuestion> {}; 478 479 /// Compare the number of operands of an operation with a known value. 480 struct OperandCountQuestion 481 : public PredicateBase<OperandCountQuestion, Qualifier, void, 482 Predicates::OperandCountQuestion> {}; 483 struct OperandCountAtLeastQuestion 484 : public PredicateBase<OperandCountAtLeastQuestion, Qualifier, void, 485 Predicates::OperandCountAtLeastQuestion> {}; 486 487 /// Compare the name of an operation with a known value. 488 struct OperationNameQuestion 489 : public PredicateBase<OperationNameQuestion, Qualifier, void, 490 Predicates::OperationNameQuestion> {}; 491 492 /// Compare the number of results of an operation with a known value. 493 struct ResultCountQuestion 494 : public PredicateBase<ResultCountQuestion, Qualifier, void, 495 Predicates::ResultCountQuestion> {}; 496 struct ResultCountAtLeastQuestion 497 : public PredicateBase<ResultCountAtLeastQuestion, Qualifier, void, 498 Predicates::ResultCountAtLeastQuestion> {}; 499 500 /// Compare the type of an attribute or value with a known type. 
501 struct TypeQuestion : public PredicateBase<TypeQuestion, Qualifier, void, 502 Predicates::TypeQuestion> {}; 503 504 //===----------------------------------------------------------------------===// 505 // PredicateUniquer 506 //===----------------------------------------------------------------------===// 507 508 /// This class provides a storage uniquer that is used to allocate predicate 509 /// instances. 510 class PredicateUniquer : public StorageUniquer { 511 public: PredicateUniquer()512 PredicateUniquer() { 513 // Register the types of Positions with the uniquer. 514 registerParametricStorageType<AttributePosition>(); 515 registerParametricStorageType<AttributeLiteralPosition>(); 516 registerParametricStorageType<ForEachPosition>(); 517 registerParametricStorageType<OperandPosition>(); 518 registerParametricStorageType<OperandGroupPosition>(); 519 registerParametricStorageType<OperationPosition>(); 520 registerParametricStorageType<ResultPosition>(); 521 registerParametricStorageType<ResultGroupPosition>(); 522 registerParametricStorageType<TypePosition>(); 523 registerParametricStorageType<TypeLiteralPosition>(); 524 registerParametricStorageType<UsersPosition>(); 525 526 // Register the types of Questions with the uniquer. 527 registerParametricStorageType<AttributeAnswer>(); 528 registerParametricStorageType<OperationNameAnswer>(); 529 registerParametricStorageType<TypeAnswer>(); 530 registerParametricStorageType<UnsignedAnswer>(); 531 registerSingletonStorageType<FalseAnswer>(); 532 registerSingletonStorageType<TrueAnswer>(); 533 534 // Register the types of Answers with the uniquer. 
535 registerParametricStorageType<ConstraintQuestion>(); 536 registerParametricStorageType<EqualToQuestion>(); 537 registerSingletonStorageType<AttributeQuestion>(); 538 registerSingletonStorageType<IsNotNullQuestion>(); 539 registerSingletonStorageType<OperandCountQuestion>(); 540 registerSingletonStorageType<OperandCountAtLeastQuestion>(); 541 registerSingletonStorageType<OperationNameQuestion>(); 542 registerSingletonStorageType<ResultCountQuestion>(); 543 registerSingletonStorageType<ResultCountAtLeastQuestion>(); 544 registerSingletonStorageType<TypeQuestion>(); 545 } 546 }; 547 548 //===----------------------------------------------------------------------===// 549 // PredicateBuilder 550 //===----------------------------------------------------------------------===// 551 552 /// This class provides utilities for constructing predicates. 553 class PredicateBuilder { 554 public: PredicateBuilder(PredicateUniquer & uniquer,MLIRContext * ctx)555 PredicateBuilder(PredicateUniquer &uniquer, MLIRContext *ctx) 556 : uniquer(uniquer), ctx(ctx) {} 557 558 //===--------------------------------------------------------------------===// 559 // Positions 560 //===--------------------------------------------------------------------===// 561 562 /// Returns the root operation position. getRoot()563 Position *getRoot() { return OperationPosition::getRoot(uniquer); } 564 565 /// Returns the parent position defining the value held by the given operand. getOperandDefiningOp(Position * p)566 OperationPosition *getOperandDefiningOp(Position *p) { 567 assert((isa<OperandPosition, OperandGroupPosition>(p)) && 568 "expected operand position"); 569 return OperationPosition::get(uniquer, p); 570 } 571 572 /// Returns the operation position equivalent to the given position. 
getPassthroughOp(Position * p)573 OperationPosition *getPassthroughOp(Position *p) { 574 assert((isa<ForEachPosition>(p)) && "expected users position"); 575 return OperationPosition::get(uniquer, p); 576 } 577 578 /// Returns an attribute position for an attribute of the given operation. getAttribute(OperationPosition * p,StringRef name)579 Position *getAttribute(OperationPosition *p, StringRef name) { 580 return AttributePosition::get(uniquer, p, StringAttr::get(ctx, name)); 581 } 582 583 /// Returns an attribute position for the given attribute. getAttributeLiteral(Attribute attr)584 Position *getAttributeLiteral(Attribute attr) { 585 return AttributeLiteralPosition::get(uniquer, attr); 586 } 587 getForEach(Position * p,unsigned id)588 Position *getForEach(Position *p, unsigned id) { 589 return ForEachPosition::get(uniquer, p, id); 590 } 591 592 /// Returns an operand position for an operand of the given operation. getOperand(OperationPosition * p,unsigned operand)593 Position *getOperand(OperationPosition *p, unsigned operand) { 594 return OperandPosition::get(uniquer, p, operand); 595 } 596 597 /// Returns a position for a group of operands of the given operation. getOperandGroup(OperationPosition * p,Optional<unsigned> group,bool isVariadic)598 Position *getOperandGroup(OperationPosition *p, Optional<unsigned> group, 599 bool isVariadic) { 600 return OperandGroupPosition::get(uniquer, p, group, isVariadic); 601 } getAllOperands(OperationPosition * p)602 Position *getAllOperands(OperationPosition *p) { 603 return getOperandGroup(p, /*group=*/llvm::None, /*isVariadic=*/true); 604 } 605 606 /// Returns a result position for a result of the given operation. getResult(OperationPosition * p,unsigned result)607 Position *getResult(OperationPosition *p, unsigned result) { 608 return ResultPosition::get(uniquer, p, result); 609 } 610 611 /// Returns a position for a group of results of the given operation. 
getResultGroup(OperationPosition * p,Optional<unsigned> group,bool isVariadic)612 Position *getResultGroup(OperationPosition *p, Optional<unsigned> group, 613 bool isVariadic) { 614 return ResultGroupPosition::get(uniquer, p, group, isVariadic); 615 } getAllResults(OperationPosition * p)616 Position *getAllResults(OperationPosition *p) { 617 return getResultGroup(p, /*group=*/llvm::None, /*isVariadic=*/true); 618 } 619 620 /// Returns a type position for the given entity. getType(Position * p)621 Position *getType(Position *p) { return TypePosition::get(uniquer, p); } 622 623 /// Returns a type position for the given type value. The value is stored 624 /// as either a TypeAttr, or an ArrayAttr of TypeAttr. getTypeLiteral(Attribute attr)625 Position *getTypeLiteral(Attribute attr) { 626 return TypeLiteralPosition::get(uniquer, attr); 627 } 628 629 /// Returns the users of a position using the value at the given operand. getUsers(Position * p,bool useRepresentative)630 UsersPosition *getUsers(Position *p, bool useRepresentative) { 631 assert((isa<OperandPosition, OperandGroupPosition, ResultPosition, 632 ResultGroupPosition>(p)) && 633 "expected result position"); 634 return UsersPosition::get(uniquer, p, useRepresentative); 635 } 636 637 //===--------------------------------------------------------------------===// 638 // Qualifiers 639 //===--------------------------------------------------------------------===// 640 641 /// An ordinal predicate consists of a "Question" and a set of acceptable 642 /// "Answers" (later converted to ordinal values). A predicate will query some 643 /// property of a positional value and decide what to do based on the result. 644 using Predicate = std::pair<Qualifier *, Qualifier *>; 645 646 /// Create a predicate comparing an attribute to a known value. 
getAttributeConstraint(Attribute attr)647 Predicate getAttributeConstraint(Attribute attr) { 648 return {AttributeQuestion::get(uniquer), 649 AttributeAnswer::get(uniquer, attr)}; 650 } 651 652 /// Create a predicate checking if two values are equal. getEqualTo(Position * pos)653 Predicate getEqualTo(Position *pos) { 654 return {EqualToQuestion::get(uniquer, pos), TrueAnswer::get(uniquer)}; 655 } 656 657 /// Create a predicate checking if two values are not equal. getNotEqualTo(Position * pos)658 Predicate getNotEqualTo(Position *pos) { 659 return {EqualToQuestion::get(uniquer, pos), FalseAnswer::get(uniquer)}; 660 } 661 662 /// Create a predicate that applies a generic constraint. getConstraint(StringRef name,ArrayRef<Position * > pos)663 Predicate getConstraint(StringRef name, ArrayRef<Position *> pos) { 664 return {ConstraintQuestion::get(uniquer, std::make_tuple(name, pos)), 665 TrueAnswer::get(uniquer)}; 666 } 667 668 /// Create a predicate comparing a value with null. getIsNotNull()669 Predicate getIsNotNull() { 670 return {IsNotNullQuestion::get(uniquer), TrueAnswer::get(uniquer)}; 671 } 672 673 /// Create a predicate comparing the number of operands of an operation to a 674 /// known value. getOperandCount(unsigned count)675 Predicate getOperandCount(unsigned count) { 676 return {OperandCountQuestion::get(uniquer), 677 UnsignedAnswer::get(uniquer, count)}; 678 } getOperandCountAtLeast(unsigned count)679 Predicate getOperandCountAtLeast(unsigned count) { 680 return {OperandCountAtLeastQuestion::get(uniquer), 681 UnsignedAnswer::get(uniquer, count)}; 682 } 683 684 /// Create a predicate comparing the name of an operation to a known value. getOperationName(StringRef name)685 Predicate getOperationName(StringRef name) { 686 return {OperationNameQuestion::get(uniquer), 687 OperationNameAnswer::get(uniquer, OperationName(name, ctx))}; 688 } 689 690 /// Create a predicate comparing the number of results of an operation to a 691 /// known value. 
getResultCount(unsigned count)692 Predicate getResultCount(unsigned count) { 693 return {ResultCountQuestion::get(uniquer), 694 UnsignedAnswer::get(uniquer, count)}; 695 } getResultCountAtLeast(unsigned count)696 Predicate getResultCountAtLeast(unsigned count) { 697 return {ResultCountAtLeastQuestion::get(uniquer), 698 UnsignedAnswer::get(uniquer, count)}; 699 } 700 701 /// Create a predicate comparing the type of an attribute or value to a known 702 /// type. The value is stored as either a TypeAttr, or an ArrayAttr of 703 /// TypeAttr. getTypeConstraint(Attribute type)704 Predicate getTypeConstraint(Attribute type) { 705 return {TypeQuestion::get(uniquer), TypeAnswer::get(uniquer, type)}; 706 } 707 708 private: 709 /// The uniquer used when allocating predicate nodes. 710 PredicateUniquer &uniquer; 711 712 /// The current MLIR context. 713 MLIRContext *ctx; 714 }; 715 716 } // namespace pdl_to_pdl_interp 717 } // namespace mlir 718 719 #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATE_H_ 720