1 //===- Nodes.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10 #define MLIR_TOOLS_PDLL_AST_NODES_H_
11 
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Tools/PDLL/AST/Types.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/SMLoc.h"
17 #include "llvm/Support/SourceMgr.h"
18 #include "llvm/Support/TrailingObjects.h"
19 
20 namespace mlir {
21 namespace pdll {
22 namespace ast {
23 class Context;
24 class Decl;
25 class Expr;
26 class NamedAttributeDecl;
27 class OpNameDecl;
28 class VariableDecl;
29 
30 //===----------------------------------------------------------------------===//
31 // Name
32 //===----------------------------------------------------------------------===//
33 
34 /// This class provides a convenient API for interacting with source names. It
35 /// contains a string name as well as the source location for that name.
36 struct Name {
37   static const Name &create(Context &ctx, StringRef name, SMRange location);
38 
39   /// Return the raw string name.
getNameName40   StringRef getName() const { return name; }
41 
42   /// Get the location of this name.
getLocName43   SMRange getLoc() const { return location; }
44 
45 private:
46   Name() = delete;
47   Name(const Name &) = delete;
48   Name &operator=(const Name &) = delete;
NameName49   Name(StringRef name, SMRange location) : name(name), location(location) {}
50 
51   /// The string name of the decl.
52   StringRef name;
53   /// The location of the decl name.
54   SMRange location;
55 };
56 
57 //===----------------------------------------------------------------------===//
58 // DeclScope
59 //===----------------------------------------------------------------------===//
60 
61 /// This class represents a scope for named AST decls. A scope determines the
62 /// visibility and lifetime of a named declaration.
63 class DeclScope {
64 public:
65   /// Create a new scope with an optional parent scope.
parent(parent)66   DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
67 
68   /// Return the parent scope of this scope, or nullptr if there is no parent.
getParentScope()69   DeclScope *getParentScope() { return parent; }
getParentScope()70   const DeclScope *getParentScope() const { return parent; }
71 
72   /// Return all of the decls within this scope.
getDecls()73   auto getDecls() const { return llvm::make_second_range(decls); }
74 
75   /// Add a new decl to the scope.
76   void add(Decl *decl);
77 
78   /// Lookup a decl with the given name starting from this scope. Returns
79   /// nullptr if no decl could be found.
80   Decl *lookup(StringRef name);
81   template <typename T>
lookup(StringRef name)82   T *lookup(StringRef name) {
83     return dyn_cast_or_null<T>(lookup(name));
84   }
lookup(StringRef name)85   const Decl *lookup(StringRef name) const {
86     return const_cast<DeclScope *>(this)->lookup(name);
87   }
88   template <typename T>
lookup(StringRef name)89   const T *lookup(StringRef name) const {
90     return dyn_cast_or_null<T>(lookup(name));
91   }
92 
93 private:
94   /// The parent scope, or null if this is a top-level scope.
95   DeclScope *parent;
96   /// The decls defined within this scope.
97   llvm::StringMap<Decl *> decls;
98 };
99 
100 //===----------------------------------------------------------------------===//
101 // Node
102 //===----------------------------------------------------------------------===//
103 
104 /// This class represents a base AST node. All AST nodes are derived from this
105 /// class, and it contains many of the base functionality for interacting with
106 /// nodes.
107 class Node {
108 public:
109   /// This CRTP class provides several utilies when defining new AST nodes.
110   template <typename T, typename BaseT>
111   class NodeBase : public BaseT {
112   public:
113     using Base = NodeBase<T, BaseT>;
114 
115     /// Provide type casting support.
classof(const Node * node)116     static bool classof(const Node *node) {
117       return node->getTypeID() == TypeID::get<T>();
118     }
119 
120   protected:
121     template <typename... Args>
NodeBase(SMRange loc,Args &&...args)122     explicit NodeBase(SMRange loc, Args &&...args)
123         : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
124   };
125 
126   /// Return the type identifier of this node.
getTypeID()127   TypeID getTypeID() const { return typeID; }
128 
129   /// Return the location of this node.
getLoc()130   SMRange getLoc() const { return loc; }
131 
132   /// Print this node to the given stream.
133   void print(raw_ostream &os) const;
134 
135   /// Walk all of the nodes including, and nested under, this node in pre-order.
136   void walk(function_ref<void(const Node *)> walkFn) const;
137   template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
138                                   WalkFnT>::template arg_t<0>>
139   std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
walk(WalkFnT && walkFn)140   walk(WalkFnT &&walkFn) const {
141     walk([&](const Node *node) {
142       if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
143         walkFn(derivedNode);
144     });
145   }
146 
147 protected:
Node(TypeID typeID,SMRange loc)148   Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
149 
150 private:
151   /// A unique type identifier for this node.
152   TypeID typeID;
153 
154   /// The location of this node.
155   SMRange loc;
156 };
157 
158 //===----------------------------------------------------------------------===//
159 // Stmt
160 //===----------------------------------------------------------------------===//
161 
162 /// This class represents a base AST Statement node.
163 class Stmt : public Node {
164 public:
165   using Node::Node;
166 
167   /// Provide type casting support.
168   static bool classof(const Node *node);
169 };
170 
171 //===----------------------------------------------------------------------===//
172 // CompoundStmt
173 //===----------------------------------------------------------------------===//
174 
175 /// This statement represents a compound statement, which contains a collection
176 /// of other statements.
177 class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
178                            private llvm::TrailingObjects<CompoundStmt, Stmt *> {
179 public:
180   static CompoundStmt *create(Context &ctx, SMRange location,
181                               ArrayRef<Stmt *> children);
182 
183   /// Return the children of this compound statement.
getChildren()184   MutableArrayRef<Stmt *> getChildren() {
185     return {getTrailingObjects<Stmt *>(), numChildren};
186   }
getChildren()187   ArrayRef<Stmt *> getChildren() const {
188     return const_cast<CompoundStmt *>(this)->getChildren();
189   }
begin()190   ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
end()191   ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
192 
193 private:
CompoundStmt(SMRange location,unsigned numChildren)194   CompoundStmt(SMRange location, unsigned numChildren)
195       : Base(location), numChildren(numChildren) {}
196 
197   /// The number of held children statements.
198   unsigned numChildren;
199 
200   // Allow access to various privates.
201   friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
202 };
203 
204 //===----------------------------------------------------------------------===//
205 // LetStmt
206 //===----------------------------------------------------------------------===//
207 
208 /// This statement represents a `let` statement in PDLL. This statement is used
209 /// to define variables.
210 class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
211 public:
212   static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
213 
214   /// Return the variable defined by this statement.
getVarDecl()215   VariableDecl *getVarDecl() const { return varDecl; }
216 
217 private:
LetStmt(SMRange loc,VariableDecl * varDecl)218   LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
219 
220   /// The variable defined by this statement.
221   VariableDecl *varDecl;
222 };
223 
224 //===----------------------------------------------------------------------===//
225 // OpRewriteStmt
226 //===----------------------------------------------------------------------===//
227 
228 /// This class represents a base operation rewrite statement. Operation rewrite
229 /// statements perform a set of transformations on a given root operation.
230 class OpRewriteStmt : public Stmt {
231 public:
232   /// Provide type casting support.
233   static bool classof(const Node *node);
234 
235   /// Return the root operation of this rewrite.
getRootOpExpr()236   Expr *getRootOpExpr() const { return rootOp; }
237 
238 protected:
OpRewriteStmt(TypeID typeID,SMRange loc,Expr * rootOp)239   OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
240       : Stmt(typeID, loc), rootOp(rootOp) {}
241 
242 protected:
243   /// The root operation being rewritten.
244   Expr *rootOp;
245 };
246 
247 //===----------------------------------------------------------------------===//
248 // EraseStmt
249 
250 /// This statement represents the `erase` statement in PDLL. This statement
251 /// erases the given root operation, corresponding roughly to the
252 /// PatternRewriter::eraseOp API.
253 class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
254 public:
255   static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
256 
257 private:
EraseStmt(SMRange loc,Expr * rootOp)258   EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
259 };
260 
261 //===----------------------------------------------------------------------===//
262 // ReplaceStmt
263 
264 /// This statement represents the `replace` statement in PDLL. This statement
265 /// replace the given root operation with a set of values, corresponding roughly
266 /// to the PatternRewriter::replaceOp API.
267 class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
268                           private llvm::TrailingObjects<ReplaceStmt, Expr *> {
269 public:
270   static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
271                              ArrayRef<Expr *> replExprs);
272 
273   /// Return the replacement values of this statement.
getReplExprs()274   MutableArrayRef<Expr *> getReplExprs() {
275     return {getTrailingObjects<Expr *>(), numReplExprs};
276   }
getReplExprs()277   ArrayRef<Expr *> getReplExprs() const {
278     return const_cast<ReplaceStmt *>(this)->getReplExprs();
279   }
280 
281 private:
ReplaceStmt(SMRange loc,Expr * rootOp,unsigned numReplExprs)282   ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
283       : Base(loc, rootOp), numReplExprs(numReplExprs) {}
284 
285   /// The number of replacement values within this statement.
286   unsigned numReplExprs;
287 
288   /// TrailingObject utilities.
289   friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
290 };
291 
292 //===----------------------------------------------------------------------===//
293 // RewriteStmt
294 
295 /// This statement represents an operation rewrite that contains a block of
296 /// nested rewrite commands. This allows for building more complex operation
297 /// rewrites that span across multiple statements, which may be unconnected.
298 class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
299 public:
300   static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
301                              CompoundStmt *rewriteBody);
302 
303   /// Return the compound rewrite body.
getRewriteBody()304   CompoundStmt *getRewriteBody() const { return rewriteBody; }
305 
306 private:
RewriteStmt(SMRange loc,Expr * rootOp,CompoundStmt * rewriteBody)307   RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
308       : Base(loc, rootOp), rewriteBody(rewriteBody) {}
309 
310   /// The body of nested rewriters within this statement.
311   CompoundStmt *rewriteBody;
312 };
313 
314 //===----------------------------------------------------------------------===//
315 // ReturnStmt
316 //===----------------------------------------------------------------------===//
317 
318 /// This statement represents a return from a "callable" like decl, e.g. a
319 /// Constraint or a Rewrite.
320 class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
321 public:
322   static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
323 
324   /// Return the result expression of this statement.
getResultExpr()325   Expr *getResultExpr() { return resultExpr; }
getResultExpr()326   const Expr *getResultExpr() const { return resultExpr; }
327 
328   /// Set the result expression of this statement.
setResultExpr(Expr * expr)329   void setResultExpr(Expr *expr) { resultExpr = expr; }
330 
331 private:
ReturnStmt(SMRange loc,Expr * resultExpr)332   ReturnStmt(SMRange loc, Expr *resultExpr)
333       : Base(loc), resultExpr(resultExpr) {}
334 
335   // The result expression of this statement.
336   Expr *resultExpr;
337 };
338 
339 //===----------------------------------------------------------------------===//
340 // Expr
341 //===----------------------------------------------------------------------===//
342 
343 /// This class represents a base AST Expression node.
344 class Expr : public Stmt {
345 public:
346   /// Return the type of this expression.
getType()347   Type getType() const { return type; }
348 
349   /// Provide type casting support.
350   static bool classof(const Node *node);
351 
352 protected:
Expr(TypeID typeID,SMRange loc,Type type)353   Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
354 
355 private:
356   /// The type of this expression.
357   Type type;
358 };
359 
360 //===----------------------------------------------------------------------===//
361 // AttributeExpr
362 //===----------------------------------------------------------------------===//
363 
364 /// This expression represents a literal MLIR Attribute, and contains the
365 /// textual assembly format of that attribute.
366 class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
367 public:
368   static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
369 
370   /// Get the raw value of this expression. This is the textual assembly format
371   /// of the MLIR Attribute.
getValue()372   StringRef getValue() const { return value; }
373 
374 private:
AttributeExpr(Context & ctx,SMRange loc,StringRef value)375   AttributeExpr(Context &ctx, SMRange loc, StringRef value)
376       : Base(loc, AttributeType::get(ctx)), value(value) {}
377 
378   /// The value referenced by this expression.
379   StringRef value;
380 };
381 
382 //===----------------------------------------------------------------------===//
383 // CallExpr
384 //===----------------------------------------------------------------------===//
385 
386 /// This expression represents a call to a decl, such as a
387 /// UserConstraintDecl/UserRewriteDecl.
388 class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
389                        private llvm::TrailingObjects<CallExpr, Expr *> {
390 public:
391   static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
392                           ArrayRef<Expr *> arguments, Type resultType);
393 
394   /// Return the callable of this call.
getCallableExpr()395   Expr *getCallableExpr() const { return callable; }
396 
397   /// Return the arguments of this call.
getArguments()398   MutableArrayRef<Expr *> getArguments() {
399     return {getTrailingObjects<Expr *>(), numArgs};
400   }
getArguments()401   ArrayRef<Expr *> getArguments() const {
402     return const_cast<CallExpr *>(this)->getArguments();
403   }
404 
405 private:
CallExpr(SMRange loc,Type type,Expr * callable,unsigned numArgs)406   CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
407       : Base(loc, type), callable(callable), numArgs(numArgs) {}
408 
409   /// The callable of this call.
410   Expr *callable;
411 
412   /// The number of arguments of the call.
413   unsigned numArgs;
414 
415   /// TrailingObject utilities.
416   friend llvm::TrailingObjects<CallExpr, Expr *>;
417 };
418 
419 //===----------------------------------------------------------------------===//
420 // DeclRefExpr
421 //===----------------------------------------------------------------------===//
422 
423 /// This expression represents a reference to a Decl node.
424 class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
425 public:
426   static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
427 
428   /// Get the decl referenced by this expression.
getDecl()429   Decl *getDecl() const { return decl; }
430 
431 private:
DeclRefExpr(SMRange loc,Decl * decl,Type type)432   DeclRefExpr(SMRange loc, Decl *decl, Type type)
433       : Base(loc, type), decl(decl) {}
434 
435   /// The decl referenced by this expression.
436   Decl *decl;
437 };
438 
439 //===----------------------------------------------------------------------===//
440 // MemberAccessExpr
441 //===----------------------------------------------------------------------===//
442 
443 /// This expression represents a named member or field access of a given parent
444 /// expression.
445 class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
446 public:
447   static MemberAccessExpr *create(Context &ctx, SMRange loc,
448                                   const Expr *parentExpr, StringRef memberName,
449                                   Type type);
450 
451   /// Get the parent expression of this access.
getParentExpr()452   const Expr *getParentExpr() const { return parentExpr; }
453 
454   /// Return the name of the member being accessed.
getMemberName()455   StringRef getMemberName() const { return memberName; }
456 
457 private:
MemberAccessExpr(SMRange loc,const Expr * parentExpr,StringRef memberName,Type type)458   MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
459                    Type type)
460       : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
461 
462   /// The parent expression of this access.
463   const Expr *parentExpr;
464 
465   /// The name of the member being accessed from the parent.
466   StringRef memberName;
467 };
468 
469 //===----------------------------------------------------------------------===//
470 // AllResultsMemberAccessExpr
471 
472 /// This class represents an instance of MemberAccessExpr that references all
473 /// results of an operation.
474 class AllResultsMemberAccessExpr : public MemberAccessExpr {
475 public:
476   /// Return the member name used for the "all-results" access.
getMemberName()477   static StringRef getMemberName() { return "$results"; }
478 
create(Context & ctx,SMRange loc,const Expr * parentExpr,Type type)479   static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
480                                             const Expr *parentExpr, Type type) {
481     return cast<AllResultsMemberAccessExpr>(
482         MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
483   }
484 
485   /// Provide type casting support.
classof(const Node * node)486   static bool classof(const Node *node) {
487     const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
488     return memAccess && memAccess->getMemberName() == getMemberName();
489   }
490 };
491 
492 //===----------------------------------------------------------------------===//
493 // OperationExpr
494 //===----------------------------------------------------------------------===//
495 
496 /// This expression represents the structural form of an MLIR Operation. It
497 /// represents either an input operation to match, or an operation to create
498 /// within a rewrite.
499 class OperationExpr final
500     : public Node::NodeBase<OperationExpr, Expr>,
501       private llvm::TrailingObjects<OperationExpr, Expr *,
502                                     NamedAttributeDecl *> {
503 public:
504   static OperationExpr *create(Context &ctx, SMRange loc,
505                                const ods::Operation *odsOp,
506                                const OpNameDecl *nameDecl,
507                                ArrayRef<Expr *> operands,
508                                ArrayRef<Expr *> resultTypes,
509                                ArrayRef<NamedAttributeDecl *> attributes);
510 
511   /// Return the name of the operation, or None if there isn't one.
512   Optional<StringRef> getName() const;
513 
514   /// Return the declaration of the operation name.
getNameDecl()515   const OpNameDecl *getNameDecl() const { return nameDecl; }
516 
517   /// Return the location of the name of the operation expression, or an invalid
518   /// location if there isn't a name.
getNameLoc()519   SMRange getNameLoc() const { return nameLoc; }
520 
521   /// Return the operands of this operation.
getOperands()522   MutableArrayRef<Expr *> getOperands() {
523     return {getTrailingObjects<Expr *>(), numOperands};
524   }
getOperands()525   ArrayRef<Expr *> getOperands() const {
526     return const_cast<OperationExpr *>(this)->getOperands();
527   }
528 
529   /// Return the result types of this operation.
getResultTypes()530   MutableArrayRef<Expr *> getResultTypes() {
531     return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
532   }
getResultTypes()533   MutableArrayRef<Expr *> getResultTypes() const {
534     return const_cast<OperationExpr *>(this)->getResultTypes();
535   }
536 
537   /// Return the attributes of this operation.
getAttributes()538   MutableArrayRef<NamedAttributeDecl *> getAttributes() {
539     return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
540   }
getAttributes()541   MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
542     return const_cast<OperationExpr *>(this)->getAttributes();
543   }
544 
545 private:
OperationExpr(SMRange loc,Type type,const OpNameDecl * nameDecl,unsigned numOperands,unsigned numResultTypes,unsigned numAttributes,SMRange nameLoc)546   OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
547                 unsigned numOperands, unsigned numResultTypes,
548                 unsigned numAttributes, SMRange nameLoc)
549       : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
550         numResultTypes(numResultTypes), numAttributes(numAttributes),
551         nameLoc(nameLoc) {}
552 
553   /// The name decl of this expression.
554   const OpNameDecl *nameDecl;
555 
556   /// The number of operands, result types, and attributes of the operation.
557   unsigned numOperands, numResultTypes, numAttributes;
558 
559   /// The location of the operation name in the expression if it has a name.
560   SMRange nameLoc;
561 
562   /// TrailingObject utilities.
563   friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
numTrailingObjects(OverloadToken<Expr * >)564   size_t numTrailingObjects(OverloadToken<Expr *>) const {
565     return numOperands + numResultTypes;
566   }
567 };
568 
569 //===----------------------------------------------------------------------===//
570 // TupleExpr
571 //===----------------------------------------------------------------------===//
572 
573 /// This expression builds a tuple from a set of element values.
574 class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
575                         private llvm::TrailingObjects<TupleExpr, Expr *> {
576 public:
577   static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
578                            ArrayRef<StringRef> elementNames);
579 
580   /// Return the element expressions of this tuple.
getElements()581   MutableArrayRef<Expr *> getElements() {
582     return {getTrailingObjects<Expr *>(), getType().size()};
583   }
getElements()584   ArrayRef<Expr *> getElements() const {
585     return const_cast<TupleExpr *>(this)->getElements();
586   }
587 
588   /// Return the tuple result type of this expression.
getType()589   TupleType getType() const { return Base::getType().cast<TupleType>(); }
590 
591 private:
TupleExpr(SMRange loc,TupleType type)592   TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
593 
594   /// TrailingObject utilities.
595   friend class llvm::TrailingObjects<TupleExpr, Expr *>;
596 };
597 
598 //===----------------------------------------------------------------------===//
599 // TypeExpr
600 //===----------------------------------------------------------------------===//
601 
602 /// This expression represents a literal MLIR Type, and contains the textual
603 /// assembly format of that type.
604 class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
605 public:
606   static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
607 
608   /// Get the raw value of this expression. This is the textual assembly format
609   /// of the MLIR Type.
getValue()610   StringRef getValue() const { return value; }
611 
612 private:
TypeExpr(Context & ctx,SMRange loc,StringRef value)613   TypeExpr(Context &ctx, SMRange loc, StringRef value)
614       : Base(loc, TypeType::get(ctx)), value(value) {}
615 
616   /// The value referenced by this expression.
617   StringRef value;
618 };
619 
620 //===----------------------------------------------------------------------===//
621 // Decl
622 //===----------------------------------------------------------------------===//
623 
624 /// This class represents the base Decl node.
625 class Decl : public Node {
626 public:
627   /// Return the name of the decl, or nullptr if it doesn't have one.
getName()628   const Name *getName() const { return name; }
629 
630   /// Provide type casting support.
631   static bool classof(const Node *node);
632 
633   /// Set the documentation comment for this decl.
634   void setDocComment(Context &ctx, StringRef comment);
635 
636   /// Return the documentation comment attached to this decl if it has been set.
637   /// Otherwise, returns None.
getDocComment()638   Optional<StringRef> getDocComment() const { return docComment; }
639 
640 protected:
641   Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
Node(typeID,loc)642       : Node(typeID, loc), name(name) {}
643 
644 private:
645   /// The name of the decl. This is optional for some decls, such as
646   /// PatternDecl.
647   const Name *name;
648 
649   /// The documentation comment attached to this decl. Defaults to None if
650   /// the comment is unset/unknown.
651   Optional<StringRef> docComment;
652 };
653 
654 //===----------------------------------------------------------------------===//
655 // ConstraintDecl
656 //===----------------------------------------------------------------------===//
657 
658 /// This class represents the base of all AST Constraint decls. Constraints
659 /// apply matcher conditions to, and define the type of PDLL variables.
660 class ConstraintDecl : public Decl {
661 public:
662   /// Provide type casting support.
663   static bool classof(const Node *node);
664 
665 protected:
666   ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
Decl(typeID,loc,name)667       : Decl(typeID, loc, name) {}
668 };
669 
670 /// This class represents a reference to a constraint, and contains a constraint
671 /// and the location of the reference.
672 struct ConstraintRef {
ConstraintRefConstraintRef673   ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
674       : constraint(constraint), referenceLoc(refLoc) {}
ConstraintRefConstraintRef675   explicit ConstraintRef(const ConstraintDecl *constraint)
676       : ConstraintRef(constraint, constraint->getLoc()) {}
677 
678   const ConstraintDecl *constraint;
679   SMRange referenceLoc;
680 };
681 
682 //===----------------------------------------------------------------------===//
683 // CoreConstraintDecl
684 //===----------------------------------------------------------------------===//
685 
686 /// This class represents the base of all "core" constraints. Core constraints
687 /// are those that generally represent a concrete IR construct, such as
688 /// `Type`s or `Value`s.
689 class CoreConstraintDecl : public ConstraintDecl {
690 public:
691   /// Provide type casting support.
692   static bool classof(const Node *node);
693 
694 protected:
695   CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
ConstraintDecl(typeID,loc,name)696       : ConstraintDecl(typeID, loc, name) {}
697 };
698 
699 //===----------------------------------------------------------------------===//
700 // AttrConstraintDecl
701 
702 /// The class represents an Attribute constraint, and constrains a variable to
703 /// be an Attribute.
704 class AttrConstraintDecl
705     : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
706 public:
707   static AttrConstraintDecl *create(Context &ctx, SMRange loc,
708                                     Expr *typeExpr = nullptr);
709 
710   /// Return the optional type the attribute is constrained to.
getTypeExpr()711   Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()712   const Expr *getTypeExpr() const { return typeExpr; }
713 
714 protected:
AttrConstraintDecl(SMRange loc,Expr * typeExpr)715   AttrConstraintDecl(SMRange loc, Expr *typeExpr)
716       : Base(loc), typeExpr(typeExpr) {}
717 
718   /// An optional type that the attribute is constrained to.
719   Expr *typeExpr;
720 };
721 
722 //===----------------------------------------------------------------------===//
723 // OpConstraintDecl
724 
725 /// The class represents an Operation constraint, and constrains a variable to
726 /// be an Operation.
727 class OpConstraintDecl
728     : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
729 public:
730   static OpConstraintDecl *create(Context &ctx, SMRange loc,
731                                   const OpNameDecl *nameDecl = nullptr);
732 
733   /// Return the name of the operation, or None if there isn't one.
734   Optional<StringRef> getName() const;
735 
736   /// Return the declaration of the operation name.
getNameDecl()737   const OpNameDecl *getNameDecl() const { return nameDecl; }
738 
739 protected:
OpConstraintDecl(SMRange loc,const OpNameDecl * nameDecl)740   explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
741       : Base(loc), nameDecl(nameDecl) {}
742 
743   /// The operation name of this constraint.
744   const OpNameDecl *nameDecl;
745 };
746 
747 //===----------------------------------------------------------------------===//
748 // TypeConstraintDecl
749 
750 /// The class represents a Type constraint, and constrains a variable to be a
751 /// Type.
752 class TypeConstraintDecl
753     : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
754 public:
755   static TypeConstraintDecl *create(Context &ctx, SMRange loc);
756 
757 protected:
758   using Base::Base;
759 };
760 
761 //===----------------------------------------------------------------------===//
762 // TypeRangeConstraintDecl
763 
764 /// The class represents a TypeRange constraint, and constrains a variable to be
765 /// a TypeRange.
766 class TypeRangeConstraintDecl
767     : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
768 public:
769   static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
770 
771 protected:
772   using Base::Base;
773 };
774 
775 //===----------------------------------------------------------------------===//
776 // ValueConstraintDecl
777 
778 /// The class represents a Value constraint, and constrains a variable to be a
779 /// Value.
780 class ValueConstraintDecl
781     : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
782 public:
783   static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
784 
785   /// Return the optional type the value is constrained to.
getTypeExpr()786   Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()787   const Expr *getTypeExpr() const { return typeExpr; }
788 
789 protected:
ValueConstraintDecl(SMRange loc,Expr * typeExpr)790   ValueConstraintDecl(SMRange loc, Expr *typeExpr)
791       : Base(loc), typeExpr(typeExpr) {}
792 
793   /// An optional type that the value is constrained to.
794   Expr *typeExpr;
795 };
796 
797 //===----------------------------------------------------------------------===//
798 // ValueRangeConstraintDecl
799 
800 /// The class represents a ValueRange constraint, and constrains a variable to
801 /// be a ValueRange.
802 class ValueRangeConstraintDecl
803     : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
804 public:
805   static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
806                                           Expr *typeExpr = nullptr);
807 
808   /// Return the optional type the value range is constrained to.
getTypeExpr()809   Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()810   const Expr *getTypeExpr() const { return typeExpr; }
811 
812 protected:
ValueRangeConstraintDecl(SMRange loc,Expr * typeExpr)813   ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
814       : Base(loc), typeExpr(typeExpr) {}
815 
816   /// An optional type that the value range is constrained to.
817   Expr *typeExpr;
818 };
819 
820 //===----------------------------------------------------------------------===//
821 // UserConstraintDecl
822 //===----------------------------------------------------------------------===//
823 
824 /// This decl represents a user defined constraint. This is either:
825 ///   * an imported native constraint
826 ///     - Similar to an external function declaration. This is a native
827 ///       constraint defined externally, and imported into PDLL via a
828 ///       declaration.
829 ///   * a native constraint defined in PDLL
830 ///     - This is a native constraint, i.e. a constraint whose implementation is
831 ///       defined in C++(or potentially some other non-PDLL language). The
832 ///       implementation of this constraint is specified as a string code block
833 ///       in PDLL.
834 ///   * a PDLL constraint
835 ///     - This is a constraint which is defined using only PDLL constructs.
836 class UserConstraintDecl final
837     : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
838       llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
839 public:
840   /// Create a native constraint with the given optional code block.
841   static UserConstraintDecl *
842   createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
843                ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
844                Type resultType, ArrayRef<StringRef> nativeInputTypes = {}) {
845     return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
846                       /*body=*/nullptr, resultType);
847   }
848 
849   /// Create a PDLL constraint with the given body.
createPDLL(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,const CompoundStmt * body,Type resultType)850   static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
851                                         ArrayRef<VariableDecl *> inputs,
852                                         ArrayRef<VariableDecl *> results,
853                                         const CompoundStmt *body,
854                                         Type resultType) {
855     return createImpl(ctx, name, inputs, /*nativeInputTypes=*/llvm::None,
856                       results, /*codeBlock=*/llvm::None, body, resultType);
857   }
858 
859   /// Return the name of the constraint.
getName()860   const Name &getName() const { return *Decl::getName(); }
861 
862   /// Return the input arguments of this constraint.
getInputs()863   MutableArrayRef<VariableDecl *> getInputs() {
864     return {getTrailingObjects<VariableDecl *>(), numInputs};
865   }
getInputs()866   ArrayRef<VariableDecl *> getInputs() const {
867     return const_cast<UserConstraintDecl *>(this)->getInputs();
868   }
869 
870   /// Return the explicit native type to use for the given input. Returns None
871   /// if no explicit type was set.
872   Optional<StringRef> getNativeInputType(unsigned index) const;
873 
874   /// Return the explicit results of the constraint declaration. May be empty,
875   /// even if the constraint has results (e.g. in the case of inferred results).
getResults()876   MutableArrayRef<VariableDecl *> getResults() {
877     return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
878   }
getResults()879   ArrayRef<VariableDecl *> getResults() const {
880     return const_cast<UserConstraintDecl *>(this)->getResults();
881   }
882 
883   /// Return the optional code block of this constraint, if this is a native
884   /// constraint with a provided implementation.
getCodeBlock()885   Optional<StringRef> getCodeBlock() const { return codeBlock; }
886 
887   /// Return the body of this constraint if this constraint is a PDLL
888   /// constraint, otherwise returns nullptr.
getBody()889   const CompoundStmt *getBody() const { return constraintBody; }
890 
891   /// Return the result type of this constraint.
getResultType()892   Type getResultType() const { return resultType; }
893 
894   /// Returns true if this constraint is external.
isExternal()895   bool isExternal() const { return !constraintBody && !codeBlock; }
896 
897 private:
898   /// Create either a PDLL constraint or a native constraint with the given
899   /// components.
900   static UserConstraintDecl *
901   createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
902              ArrayRef<StringRef> nativeInputTypes,
903              ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
904              const CompoundStmt *body, Type resultType);
905 
UserConstraintDecl(const Name & name,unsigned numInputs,bool hasNativeInputTypes,unsigned numResults,Optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)906   UserConstraintDecl(const Name &name, unsigned numInputs,
907                      bool hasNativeInputTypes, unsigned numResults,
908                      Optional<StringRef> codeBlock, const CompoundStmt *body,
909                      Type resultType)
910       : Base(name.getLoc(), &name), numInputs(numInputs),
911         numResults(numResults), codeBlock(codeBlock), constraintBody(body),
912         resultType(resultType) {}
913 
914   /// The number of inputs to this constraint.
915   unsigned numInputs;
916 
917   /// The number of explicit results to this constraint.
918   unsigned numResults;
919 
920   /// The optional code block of this constraint.
921   Optional<StringRef> codeBlock;
922 
923   /// The optional body of this constraint.
924   const CompoundStmt *constraintBody;
925 
926   /// The result type of the constraint.
927   Type resultType;
928 
929   /// Flag indicating if this constraint has explicit native input types.
930   bool hasNativeInputTypes;
931 
932   /// Allow access to various internals.
933   friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
numTrailingObjects(OverloadToken<VariableDecl * >)934   size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
935     return numInputs + numResults;
936   }
937 };
938 
939 //===----------------------------------------------------------------------===//
940 // NamedAttributeDecl
941 //===----------------------------------------------------------------------===//
942 
943 /// This Decl represents a NamedAttribute, and contains a string name and
944 /// attribute value.
945 class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
946 public:
947   static NamedAttributeDecl *create(Context &ctx, const Name &name,
948                                     Expr *value);
949 
950   /// Return the name of the attribute.
getName()951   const Name &getName() const { return *Decl::getName(); }
952 
953   /// Return value of the attribute.
getValue()954   Expr *getValue() const { return value; }
955 
956 private:
NamedAttributeDecl(const Name & name,Expr * value)957   NamedAttributeDecl(const Name &name, Expr *value)
958       : Base(name.getLoc(), &name), value(value) {}
959 
960   /// The value of the attribute.
961   Expr *value;
962 };
963 
964 //===----------------------------------------------------------------------===//
965 // OpNameDecl
966 //===----------------------------------------------------------------------===//
967 
968 /// This Decl represents an OperationName.
969 class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
970 public:
971   static OpNameDecl *create(Context &ctx, const Name &name);
972   static OpNameDecl *create(Context &ctx, SMRange loc);
973 
974   /// Return the name of this operation, or none if the name is unknown.
getName()975   Optional<StringRef> getName() const {
976     const Name *name = Decl::getName();
977     return name ? Optional<StringRef>(name->getName()) : llvm::None;
978   }
979 
980 private:
OpNameDecl(const Name & name)981   explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
OpNameDecl(SMRange loc)982   explicit OpNameDecl(SMRange loc) : Base(loc) {}
983 };
984 
985 //===----------------------------------------------------------------------===//
986 // PatternDecl
987 //===----------------------------------------------------------------------===//
988 
989 /// This Decl represents a single Pattern.
990 class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
991 public:
992   static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
993                              Optional<uint16_t> benefit,
994                              bool hasBoundedRecursion,
995                              const CompoundStmt *body);
996 
997   /// Return the benefit of this pattern if specified, or None.
getBenefit()998   Optional<uint16_t> getBenefit() const { return benefit; }
999 
1000   /// Return if this pattern has bounded rewrite recursion.
hasBoundedRewriteRecursion()1001   bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1002 
1003   /// Return the body of this pattern.
getBody()1004   const CompoundStmt *getBody() const { return patternBody; }
1005 
1006   /// Return the root rewrite statement of this pattern.
getRootRewriteStmt()1007   const OpRewriteStmt *getRootRewriteStmt() const {
1008     return cast<OpRewriteStmt>(patternBody->getChildren().back());
1009   }
1010 
1011 private:
PatternDecl(SMRange loc,const Name * name,Optional<uint16_t> benefit,bool hasBoundedRecursion,const CompoundStmt * body)1012   PatternDecl(SMRange loc, const Name *name, Optional<uint16_t> benefit,
1013               bool hasBoundedRecursion, const CompoundStmt *body)
1014       : Base(loc, name), benefit(benefit),
1015         hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1016 
1017   /// The benefit of the pattern if it was explicitly specified, None otherwise.
1018   Optional<uint16_t> benefit;
1019 
1020   /// If the pattern has properly bounded rewrite recursion or not.
1021   bool hasBoundedRecursion;
1022 
1023   /// The compound statement representing the body of the pattern.
1024   const CompoundStmt *patternBody;
1025 };
1026 
1027 //===----------------------------------------------------------------------===//
1028 // UserRewriteDecl
1029 //===----------------------------------------------------------------------===//
1030 
1031 /// This decl represents a user defined rewrite. This is either:
1032 ///   * an imported native rewrite
1033 ///     - Similar to an external function declaration. This is a native
1034 ///       rewrite defined externally, and imported into PDLL via a declaration.
1035 ///   * a native rewrite defined in PDLL
1036 ///     - This is a native rewrite, i.e. a rewrite whose implementation is
1037 ///       defined in C++(or potentially some other non-PDLL language). The
1038 ///       implementation of this rewrite is specified as a string code block
1039 ///       in PDLL.
1040 ///   * a PDLL rewrite
1041 ///     - This is a rewrite which is defined using only PDLL constructs.
1042 class UserRewriteDecl final
1043     : public Node::NodeBase<UserRewriteDecl, Decl>,
1044       llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1045 public:
1046   /// Create a native rewrite with the given optional code block.
createNative(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,Optional<StringRef> codeBlock,Type resultType)1047   static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1048                                        ArrayRef<VariableDecl *> inputs,
1049                                        ArrayRef<VariableDecl *> results,
1050                                        Optional<StringRef> codeBlock,
1051                                        Type resultType) {
1052     return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
1053                       resultType);
1054   }
1055 
1056   /// Create a PDLL rewrite with the given body.
createPDLL(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,const CompoundStmt * body,Type resultType)1057   static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1058                                      ArrayRef<VariableDecl *> inputs,
1059                                      ArrayRef<VariableDecl *> results,
1060                                      const CompoundStmt *body,
1061                                      Type resultType) {
1062     return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
1063                       body, resultType);
1064   }
1065 
1066   /// Return the name of the rewrite.
getName()1067   const Name &getName() const { return *Decl::getName(); }
1068 
1069   /// Return the input arguments of this rewrite.
getInputs()1070   MutableArrayRef<VariableDecl *> getInputs() {
1071     return {getTrailingObjects<VariableDecl *>(), numInputs};
1072   }
getInputs()1073   ArrayRef<VariableDecl *> getInputs() const {
1074     return const_cast<UserRewriteDecl *>(this)->getInputs();
1075   }
1076 
1077   /// Return the explicit results of the rewrite declaration. May be empty,
1078   /// even if the rewrite has results (e.g. in the case of inferred results).
getResults()1079   MutableArrayRef<VariableDecl *> getResults() {
1080     return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
1081   }
getResults()1082   ArrayRef<VariableDecl *> getResults() const {
1083     return const_cast<UserRewriteDecl *>(this)->getResults();
1084   }
1085 
1086   /// Return the optional code block of this rewrite, if this is a native
1087   /// rewrite with a provided implementation.
getCodeBlock()1088   Optional<StringRef> getCodeBlock() const { return codeBlock; }
1089 
1090   /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1091   /// otherwise returns nullptr.
getBody()1092   const CompoundStmt *getBody() const { return rewriteBody; }
1093 
1094   /// Return the result type of this rewrite.
getResultType()1095   Type getResultType() const { return resultType; }
1096 
1097   /// Returns true if this rewrite is external.
isExternal()1098   bool isExternal() const { return !rewriteBody && !codeBlock; }
1099 
1100 private:
1101   /// Create either a PDLL rewrite or a native rewrite with the given
1102   /// components.
1103   static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1104                                      ArrayRef<VariableDecl *> inputs,
1105                                      ArrayRef<VariableDecl *> results,
1106                                      Optional<StringRef> codeBlock,
1107                                      const CompoundStmt *body, Type resultType);
1108 
UserRewriteDecl(const Name & name,unsigned numInputs,unsigned numResults,Optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)1109   UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1110                   Optional<StringRef> codeBlock, const CompoundStmt *body,
1111                   Type resultType)
1112       : Base(name.getLoc(), &name), numInputs(numInputs),
1113         numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1114         resultType(resultType) {}
1115 
1116   /// The number of inputs to this rewrite.
1117   unsigned numInputs;
1118 
1119   /// The number of explicit results to this rewrite.
1120   unsigned numResults;
1121 
1122   /// The optional code block of this rewrite.
1123   Optional<StringRef> codeBlock;
1124 
1125   /// The optional body of this rewrite.
1126   const CompoundStmt *rewriteBody;
1127 
1128   /// The result type of the rewrite.
1129   Type resultType;
1130 
1131   /// Allow access to various internals.
1132   friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1133 };
1134 
1135 //===----------------------------------------------------------------------===//
1136 // CallableDecl
1137 //===----------------------------------------------------------------------===//
1138 
1139 /// This decl represents a shared interface for all callable decls.
1140 class CallableDecl : public Decl {
1141 public:
1142   /// Return the callable type of this decl.
getCallableType()1143   StringRef getCallableType() const {
1144     if (isa<UserConstraintDecl>(this))
1145       return "constraint";
1146     assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1147     return "rewrite";
1148   }
1149 
1150   /// Return the inputs of this decl.
getInputs()1151   ArrayRef<VariableDecl *> getInputs() const {
1152     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1153       return cst->getInputs();
1154     return cast<UserRewriteDecl>(this)->getInputs();
1155   }
1156 
1157   /// Return the result type of this decl.
getResultType()1158   Type getResultType() const {
1159     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1160       return cst->getResultType();
1161     return cast<UserRewriteDecl>(this)->getResultType();
1162   }
1163 
1164   /// Return the explicit results of the declaration. Note that these may be
1165   /// empty, even if the callable has results (e.g. in the case of inferred
1166   /// results).
getResults()1167   ArrayRef<VariableDecl *> getResults() const {
1168     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1169       return cst->getResults();
1170     return cast<UserRewriteDecl>(this)->getResults();
1171   }
1172 
1173   /// Return the optional code block of this callable, if this is a native
1174   /// callable with a provided implementation.
getCodeBlock()1175   Optional<StringRef> getCodeBlock() const {
1176     if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1177       return cst->getCodeBlock();
1178     return cast<UserRewriteDecl>(this)->getCodeBlock();
1179   }
1180 
1181   /// Support LLVM type casting facilities.
classof(const Node * decl)1182   static bool classof(const Node *decl) {
1183     return isa<UserConstraintDecl, UserRewriteDecl>(decl);
1184   }
1185 };
1186 
1187 //===----------------------------------------------------------------------===//
1188 // VariableDecl
1189 //===----------------------------------------------------------------------===//
1190 
1191 /// This Decl represents the definition of a PDLL variable.
1192 class VariableDecl final
1193     : public Node::NodeBase<VariableDecl, Decl>,
1194       private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1195 public:
1196   static VariableDecl *create(Context &ctx, const Name &name, Type type,
1197                               Expr *initExpr,
1198                               ArrayRef<ConstraintRef> constraints);
1199 
1200   /// Return the constraints of this variable.
getConstraints()1201   MutableArrayRef<ConstraintRef> getConstraints() {
1202     return {getTrailingObjects<ConstraintRef>(), numConstraints};
1203   }
getConstraints()1204   ArrayRef<ConstraintRef> getConstraints() const {
1205     return const_cast<VariableDecl *>(this)->getConstraints();
1206   }
1207 
1208   /// Return the initializer expression of this statement, or nullptr if there
1209   /// was no initializer.
getInitExpr()1210   Expr *getInitExpr() const { return initExpr; }
1211 
1212   /// Return the name of the decl.
getName()1213   const Name &getName() const { return *Decl::getName(); }
1214 
1215   /// Return the type of the decl.
getType()1216   Type getType() const { return type; }
1217 
1218 private:
VariableDecl(const Name & name,Type type,Expr * initExpr,unsigned numConstraints)1219   VariableDecl(const Name &name, Type type, Expr *initExpr,
1220                unsigned numConstraints)
1221       : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1222         numConstraints(numConstraints) {}
1223 
1224   /// The type of the variable.
1225   Type type;
1226 
1227   /// The optional initializer expression of this statement.
1228   Expr *initExpr;
1229 
1230   /// The number of constraints attached to this variable.
1231   unsigned numConstraints;
1232 
1233   /// Allow access to various internals.
1234   friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1235 };
1236 
1237 //===----------------------------------------------------------------------===//
1238 // Module
1239 //===----------------------------------------------------------------------===//
1240 
1241 /// This class represents a top-level AST module.
1242 class Module final : public Node::NodeBase<Module, Node>,
1243                      private llvm::TrailingObjects<Module, Decl *> {
1244 public:
1245   static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1246 
1247   /// Return the children of this module.
getChildren()1248   MutableArrayRef<Decl *> getChildren() {
1249     return {getTrailingObjects<Decl *>(), numChildren};
1250   }
getChildren()1251   ArrayRef<Decl *> getChildren() const {
1252     return const_cast<Module *>(this)->getChildren();
1253   }
1254 
1255 private:
Module(SMLoc loc,unsigned numChildren)1256   Module(SMLoc loc, unsigned numChildren)
1257       : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1258 
1259   /// The number of decls held by this module.
1260   unsigned numChildren;
1261 
1262   /// Allow access to various internals.
1263   friend llvm::TrailingObjects<Module, Decl *>;
1264 };
1265 
1266 //===----------------------------------------------------------------------===//
// Deferred Method Definitions
1268 //===----------------------------------------------------------------------===//
1269 
classof(const Node * node)1270 inline bool Decl::classof(const Node *node) {
1271   return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
1272              UserRewriteDecl, VariableDecl>(node);
1273 }
1274 
classof(const Node * node)1275 inline bool ConstraintDecl::classof(const Node *node) {
1276   return isa<CoreConstraintDecl, UserConstraintDecl>(node);
1277 }
1278 
classof(const Node * node)1279 inline bool CoreConstraintDecl::classof(const Node *node) {
1280   return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
1281              TypeRangeConstraintDecl, ValueConstraintDecl,
1282              ValueRangeConstraintDecl>(node);
1283 }
1284 
classof(const Node * node)1285 inline bool Expr::classof(const Node *node) {
1286   return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr,
1287              OperationExpr, TupleExpr, TypeExpr>(node);
1288 }
1289 
classof(const Node * node)1290 inline bool OpRewriteStmt::classof(const Node *node) {
1291   return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
1292 }
1293 
classof(const Node * node)1294 inline bool Stmt::classof(const Node *node) {
1295   return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
1296 }
1297 
1298 } // namespace ast
1299 } // namespace pdll
1300 } // namespace mlir
1301 
1302 #endif // MLIR_TOOLS_PDLL_AST_NODES_H_
1303