1 //===- PredicateTree.h - Predicate tree node definitions --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains definitions for nodes of a tree structure for representing
10 // the general control flow within a pattern match.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #ifndef MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
15 #define MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
16 
17 #include "Predicate.h"
18 #include "mlir/Dialect/PDL/IR/PDL.h"
19 #include "llvm/ADT/MapVector.h"
20 
21 namespace mlir {
22 namespace pdl_to_pdl_interp {
23 
24 class MatcherNode;
25 
26 /// A PositionalPredicate is a predicate that is associated with a specific
27 /// positional value.
28 struct PositionalPredicate {
29   PositionalPredicate(Position *pos,
30                       const PredicateBuilder::Predicate &predicate)
31       : position(pos), question(predicate.first), answer(predicate.second) {}
32 
33   /// The position the predicate is applied to.
34   Position *position;
35 
36   /// The question that the predicate applies.
37   Qualifier *question;
38 
39   /// The expected answer of the predicate.
40   Qualifier *answer;
41 };
42 
43 //===----------------------------------------------------------------------===//
44 // MatcherNode
45 //===----------------------------------------------------------------------===//
46 
47 /// This class represents the base of a predicate matcher node.
48 class MatcherNode {
49 public:
50   virtual ~MatcherNode() = default;
51 
52   /// Given a module containing PDL pattern operations, generate a matcher tree
53   /// using the patterns within the given module and return the root matcher
54   /// node. `valueToPosition` is a map that is populated with the original
55   /// pdl values and their corresponding positions in the matcher tree.
56   static std::unique_ptr<MatcherNode>
57   generateMatcherTree(ModuleOp module, PredicateBuilder &builder,
58                       DenseMap<Value, Position *> &valueToPosition);
59 
60   /// Returns the position on which the question predicate should be checked.
61   Position *getPosition() const { return position; }
62 
63   /// Returns the predicate checked on this node.
64   Qualifier *getQuestion() const { return question; }
65 
66   /// Returns the node that should be visited if this, or a subsequent node
67   /// fails.
68   std::unique_ptr<MatcherNode> &getFailureNode() { return failureNode; }
69 
70   /// Sets the node that should be visited if this, or a subsequent node fails.
71   void setFailureNode(std::unique_ptr<MatcherNode> node) {
72     failureNode = std::move(node);
73   }
74 
75   /// Returns the unique type ID of this matcher instance. This should not be
76   /// used directly, and is provided to support type casting.
77   TypeID getMatcherTypeID() const { return matcherTypeID; }
78 
79 protected:
80   MatcherNode(TypeID matcherTypeID, Position *position = nullptr,
81               Qualifier *question = nullptr,
82               std::unique_ptr<MatcherNode> failureNode = nullptr);
83 
84 private:
85   /// The position on which the predicate should be checked.
86   Position *position;
87 
88   /// The predicate that is checked on the given position.
89   Qualifier *question;
90 
91   /// The node to visit if this node fails.
92   std::unique_ptr<MatcherNode> failureNode;
93 
94   /// An owning store for the failure node if it is owned by this node.
95   std::unique_ptr<MatcherNode> failureNodeStorage;
96 
97   /// A unique identifier for the derived matcher node, used for type casting.
98   TypeID matcherTypeID;
99 };
100 
101 //===----------------------------------------------------------------------===//
102 // BoolNode
103 
104 /// A BoolNode denotes a question with a boolean-like result. These nodes branch
105 /// to a single node on a successful result, otherwise defaulting to the failure
106 /// node.
107 struct BoolNode : public MatcherNode {
108   BoolNode(Position *position, Qualifier *question, Qualifier *answer,
109            std::unique_ptr<MatcherNode> successNode,
110            std::unique_ptr<MatcherNode> failureNode = nullptr);
111 
112   /// Returns if the given matcher node is an instance of this class, used to
113   /// support type casting.
114   static bool classof(const MatcherNode *node) {
115     return node->getMatcherTypeID() == TypeID::get<BoolNode>();
116   }
117 
118   /// Returns the expected answer of this boolean node.
119   Qualifier *getAnswer() const { return answer; }
120 
121   /// Returns the node that should be visited on success.
122   std::unique_ptr<MatcherNode> &getSuccessNode() { return successNode; }
123 
124 private:
125   /// The expected answer of this boolean node.
126   Qualifier *answer;
127 
128   /// The next node if this node succeeds. Otherwise, go to the failure node.
129   std::unique_ptr<MatcherNode> successNode;
130 };
131 
132 //===----------------------------------------------------------------------===//
133 // ExitNode
134 
135 /// An ExitNode is a special sentinel node that denotes the end of matcher.
136 struct ExitNode : public MatcherNode {
137   ExitNode() : MatcherNode(TypeID::get<ExitNode>()) {}
138 
139   /// Returns if the given matcher node is an instance of this class, used to
140   /// support type casting.
141   static bool classof(const MatcherNode *node) {
142     return node->getMatcherTypeID() == TypeID::get<ExitNode>();
143   }
144 };
145 
146 //===----------------------------------------------------------------------===//
147 // SuccessNode
148 
149 /// A SuccessNode denotes that a given high level pattern has successfully been
150 /// matched. This does not terminate the matcher, as there may be multiple
151 /// successful matches.
152 struct SuccessNode : public MatcherNode {
153   explicit SuccessNode(pdl::PatternOp pattern,
154                        std::unique_ptr<MatcherNode> failureNode);
155 
156   /// Returns if the given matcher node is an instance of this class, used to
157   /// support type casting.
158   static bool classof(const MatcherNode *node) {
159     return node->getMatcherTypeID() == TypeID::get<SuccessNode>();
160   }
161 
162   /// Return the high level pattern operation that is matched with this node.
163   pdl::PatternOp getPattern() const { return pattern; }
164 
165 private:
166   /// The high level pattern operation that was successfully matched with this
167   /// node.
168   pdl::PatternOp pattern;
169 };
170 
171 //===----------------------------------------------------------------------===//
172 // SwitchNode
173 
174 /// A SwitchNode denotes a question with multiple potential results. These nodes
175 /// branch to a specific node based on the result of the question.
176 struct SwitchNode : public MatcherNode {
177   SwitchNode(Position *position, Qualifier *question);
178 
179   /// Returns if the given matcher node is an instance of this class, used to
180   /// support type casting.
181   static bool classof(const MatcherNode *node) {
182     return node->getMatcherTypeID() == TypeID::get<SwitchNode>();
183   }
184 
185   /// Returns the children of this switch node. The children are contained
186   /// within a mapping between the various case answers to destination matcher
187   /// nodes.
188   using ChildMapT = llvm::MapVector<Qualifier *, std::unique_ptr<MatcherNode>>;
189   ChildMapT &getChildren() { return children; }
190 
191 private:
192   /// Switch predicate "answers" select the child. Answers that are not found
193   /// default to the failure node.
194   ChildMapT children;
195 };
196 
197 } // end namespace pdl_to_pdl_interp
198 } // end namespace mlir
199 
#endif // MLIR_LIB_CONVERSION_PDLTOPDLINTERP_PREDICATETREE_H_
201