1 //===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains definitions for nodes of a tree structure for representing 10 // the general control flow within a pattern match. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 15 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 16 17 #include "Predicate.h" 18 #include "mlir/Dialect/PDL/IR/PDL.h" 19 #include "llvm/ADT/MapVector.h" 20 21 namespace mlir { 22 namespace pdl_to_pdl_interp { 23 24 class MatcherNode; 25 26 /// A PositionalPredicate is a predicate that is associated with a specific 27 /// positional value. 28 struct PositionalPredicate { 29 PositionalPredicate(Position *pos, 30 const PredicateBuilder::Predicate &predicate) 31 : position(pos), question(predicate.first), answer(predicate.second) {} 32 33 /// The position the predicate is applied to. 34 Position *position; 35 36 /// The question that the predicate applies. 37 Qualifier *question; 38 39 /// The expected answer of the predicate. 40 Qualifier *answer; 41 }; 42 43 //===----------------------------------------------------------------------===// 44 // MatcherNode 45 //===----------------------------------------------------------------------===// 46 47 /// This class represents the base of a predicate matcher node. 48 class MatcherNode { 49 public: 50 virtual ~MatcherNode() = default; 51 52 /// Given a module containing PDL pattern operations, generate a matcher tree 53 /// using the patterns within the given module and return the root matcher 54 /// node. `valueToPosition` is a map that is populated with the original 55 /// pdl values and their corresponding positions in the matcher tree. 56 static std::unique_ptr<MatcherNode> 57 generateMatcherTree(ModuleOp module, PredicateBuilder &builder, 58 DenseMap<Value, Position *> &valueToPosition); 59 60 /// Returns the position on which the question predicate should be checked. 61 Position *getPosition() const { return position; } 62 63 /// Returns the predicate checked on this node. 64 Qualifier *getQuestion() const { return question; } 65 66 /// Returns the node that should be visited if this, or a subsequent node 67 /// fails. 68 std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; } 69 70 /// Sets the node that should be visited if this, or a subsequent node fails. 71 void setFailureNode(std::unique_ptr<MatcherNode> node) { 72 failureNode = std::move(node); 73 } 74 75 /// Returns the unique type ID of this matcher instance. This should not be 76 /// used directly, and is provided to support type casting. 77 TypeID getMatcherTypeID() const { return matcherTypeID; } 78 79 protected: 80 MatcherNode(TypeID matcherTypeID, Position *position = nullptr, 81 Qualifier *question = nullptr, 82 std::unique_ptr<MatcherNode> failureNode = nullptr); 83 84 private: 85 /// The position on which the predicate should be checked. 86 Position *position; 87 88 /// The predicate that is checked on the given position. 89 Qualifier *question; 90 91 /// The node to visit if this node fails. 92 std::unique_ptr<MatcherNode> failureNode; 93 94 /// An owning store for the failure node if it is owned by this node. 95 std::unique_ptr<MatcherNode> failureNodeStorage; 96 97 /// A unique identifier for the derived matcher node, used for type casting. 98 TypeID matcherTypeID; 99 }; 100 101 //===----------------------------------------------------------------------===// 102 // BoolNode 103 104 /// A BoolNode denotes a question with a boolean-like result. These nodes branch 105 /// to a single node on a successful result, otherwise defaulting to the failure 106 /// node. 107 struct BoolNode : public MatcherNode { 108 BoolNode(Position *position, Qualifier *question, Qualifier *answer, 109 std::unique_ptr<MatcherNode> successNode, 110 std::unique_ptr<MatcherNode> failureNode = nullptr); 111 112 /// Returns if the given matcher node is an instance of this class, used to 113 /// support type casting. 114 static bool classof(const MatcherNode *node) { 115 return node->getMatcherTypeID() == TypeID::get<BoolNode>(); 116 } 117 118 /// Returns the expected answer of this boolean node. 119 Qualifier *getAnswer() const { return answer; } 120 121 /// Returns the node that should be visited on success. 122 std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; } 123 124 private: 125 /// The expected answer of this boolean node. 126 Qualifier *answer; 127 128 /// The next node if this node succeeds. Otherwise, go to the failure node. 129 std::unique_ptr<MatcherNode> successNode; 130 }; 131 132 //===----------------------------------------------------------------------===// 133 // ExitNode 134 135 /// An ExitNode is a special sentinel node that denotes the end of matcher. 136 struct ExitNode : public MatcherNode { 137 ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {} 138 139 /// Returns if the given matcher node is an instance of this class, used to 140 /// support type casting. 141 static bool classof(const MatcherNode *node) { 142 return node->getMatcherTypeID() == TypeID::get<ExitNode>(); 143 } 144 }; 145 146 //===----------------------------------------------------------------------===// 147 // SuccessNode 148 149 /// A SuccessNode denotes that a given high level pattern has successfully been 150 /// matched. This does not terminate the matcher, as there may be multiple 151 /// successful matches. 152 struct SuccessNode : public MatcherNode { 153 explicit SuccessNode(pdl::PatternOp pattern, 154 std::unique_ptr<MatcherNode> failureNode); 155 156 /// Returns if the given matcher node is an instance of this class, used to 157 /// support type casting. 158 static bool classof(const MatcherNode *node) { 159 return node->getMatcherTypeID() == TypeID::get<SuccessNode>(); 160 } 161 162 /// Return the high level pattern operation that is matched with this node. 163 pdl::PatternOp getPattern() const { return pattern; } 164 165 private: 166 /// The high level pattern operation that was successfully matched with this 167 /// node. 168 pdl::PatternOp pattern; 169 }; 170 171 //===----------------------------------------------------------------------===// 172 // SwitchNode 173 174 /// A SwitchNode denotes a question with multiple potential results. These nodes 175 /// branch to a specific node based on the result of the question. 176 struct SwitchNode : public MatcherNode { 177 SwitchNode(Position *position, Qualifier *question); 178 179 /// Returns if the given matcher node is an instance of this class, used to 180 /// support type casting. 181 static bool classof(const MatcherNode *node) { 182 return node->getMatcherTypeID() == TypeID::get<SwitchNode>(); 183 } 184 185 /// Returns the children of this switch node. The children are contained 186 /// within a mapping between the various case answers to destination matcher 187 /// nodes. 188 using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>; 189 ChildMapT &getChildren() { return children; } 190 191 private: 192 /// Switch predicate "answers" select the child. Answers that are not found 193 /// default to the failure node. 194 ChildMapT children; 195 }; 196 197 } // end namespace pdl_to_pdl_interp 198 } // end namespace mlir 199 200 #endif // MLIR_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_ 201