1 //===- Nodes.h --------------------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #ifndef MLIR_TOOLS_PDLL_AST_NODES_H_
10 #define MLIR_TOOLS_PDLL_AST_NODES_H_
11
12 #include "mlir/Support/LLVM.h"
13 #include "mlir/Tools/PDLL/AST/Types.h"
14 #include "llvm/ADT/StringMap.h"
15 #include "llvm/ADT/StringRef.h"
16 #include "llvm/Support/SMLoc.h"
17 #include "llvm/Support/SourceMgr.h"
18 #include "llvm/Support/TrailingObjects.h"
19
20 namespace mlir {
21 namespace pdll {
22 namespace ast {
23 class Context;
24 class Decl;
25 class Expr;
26 class NamedAttributeDecl;
27 class OpNameDecl;
28 class VariableDecl;
29
30 //===----------------------------------------------------------------------===//
31 // Name
32 //===----------------------------------------------------------------------===//
33
34 /// This class provides a convenient API for interacting with source names. It
35 /// contains a string name as well as the source location for that name.
36 struct Name {
37 static const Name &create(Context &ctx, StringRef name, SMRange location);
38
39 /// Return the raw string name.
getNameName40 StringRef getName() const { return name; }
41
42 /// Get the location of this name.
getLocName43 SMRange getLoc() const { return location; }
44
45 private:
46 Name() = delete;
47 Name(const Name &) = delete;
48 Name &operator=(const Name &) = delete;
NameName49 Name(StringRef name, SMRange location) : name(name), location(location) {}
50
51 /// The string name of the decl.
52 StringRef name;
53 /// The location of the decl name.
54 SMRange location;
55 };
56
57 //===----------------------------------------------------------------------===//
58 // DeclScope
59 //===----------------------------------------------------------------------===//
60
61 /// This class represents a scope for named AST decls. A scope determines the
62 /// visibility and lifetime of a named declaration.
63 class DeclScope {
64 public:
65 /// Create a new scope with an optional parent scope.
parent(parent)66 DeclScope(DeclScope *parent = nullptr) : parent(parent) {}
67
68 /// Return the parent scope of this scope, or nullptr if there is no parent.
getParentScope()69 DeclScope *getParentScope() { return parent; }
getParentScope()70 const DeclScope *getParentScope() const { return parent; }
71
72 /// Return all of the decls within this scope.
getDecls()73 auto getDecls() const { return llvm::make_second_range(decls); }
74
75 /// Add a new decl to the scope.
76 void add(Decl *decl);
77
78 /// Lookup a decl with the given name starting from this scope. Returns
79 /// nullptr if no decl could be found.
80 Decl *lookup(StringRef name);
81 template <typename T>
lookup(StringRef name)82 T *lookup(StringRef name) {
83 return dyn_cast_or_null<T>(lookup(name));
84 }
lookup(StringRef name)85 const Decl *lookup(StringRef name) const {
86 return const_cast<DeclScope *>(this)->lookup(name);
87 }
88 template <typename T>
lookup(StringRef name)89 const T *lookup(StringRef name) const {
90 return dyn_cast_or_null<T>(lookup(name));
91 }
92
93 private:
94 /// The parent scope, or null if this is a top-level scope.
95 DeclScope *parent;
96 /// The decls defined within this scope.
97 llvm::StringMap<Decl *> decls;
98 };
99
100 //===----------------------------------------------------------------------===//
101 // Node
102 //===----------------------------------------------------------------------===//
103
104 /// This class represents a base AST node. All AST nodes are derived from this
105 /// class, and it contains many of the base functionality for interacting with
106 /// nodes.
107 class Node {
108 public:
109 /// This CRTP class provides several utilies when defining new AST nodes.
110 template <typename T, typename BaseT>
111 class NodeBase : public BaseT {
112 public:
113 using Base = NodeBase<T, BaseT>;
114
115 /// Provide type casting support.
classof(const Node * node)116 static bool classof(const Node *node) {
117 return node->getTypeID() == TypeID::get<T>();
118 }
119
120 protected:
121 template <typename... Args>
NodeBase(SMRange loc,Args &&...args)122 explicit NodeBase(SMRange loc, Args &&...args)
123 : BaseT(TypeID::get<T>(), loc, std::forward<Args>(args)...) {}
124 };
125
126 /// Return the type identifier of this node.
getTypeID()127 TypeID getTypeID() const { return typeID; }
128
129 /// Return the location of this node.
getLoc()130 SMRange getLoc() const { return loc; }
131
132 /// Print this node to the given stream.
133 void print(raw_ostream &os) const;
134
135 /// Walk all of the nodes including, and nested under, this node in pre-order.
136 void walk(function_ref<void(const Node *)> walkFn) const;
137 template <typename WalkFnT, typename ArgT = typename llvm::function_traits<
138 WalkFnT>::template arg_t<0>>
139 std::enable_if_t<!std::is_convertible<const Node *, ArgT>::value>
walk(WalkFnT && walkFn)140 walk(WalkFnT &&walkFn) const {
141 walk([&](const Node *node) {
142 if (const ArgT *derivedNode = dyn_cast<ArgT>(node))
143 walkFn(derivedNode);
144 });
145 }
146
147 protected:
Node(TypeID typeID,SMRange loc)148 Node(TypeID typeID, SMRange loc) : typeID(typeID), loc(loc) {}
149
150 private:
151 /// A unique type identifier for this node.
152 TypeID typeID;
153
154 /// The location of this node.
155 SMRange loc;
156 };
157
158 //===----------------------------------------------------------------------===//
159 // Stmt
160 //===----------------------------------------------------------------------===//
161
162 /// This class represents a base AST Statement node.
163 class Stmt : public Node {
164 public:
165 using Node::Node;
166
167 /// Provide type casting support.
168 static bool classof(const Node *node);
169 };
170
171 //===----------------------------------------------------------------------===//
172 // CompoundStmt
173 //===----------------------------------------------------------------------===//
174
175 /// This statement represents a compound statement, which contains a collection
176 /// of other statements.
177 class CompoundStmt final : public Node::NodeBase<CompoundStmt, Stmt>,
178 private llvm::TrailingObjects<CompoundStmt, Stmt *> {
179 public:
180 static CompoundStmt *create(Context &ctx, SMRange location,
181 ArrayRef<Stmt *> children);
182
183 /// Return the children of this compound statement.
getChildren()184 MutableArrayRef<Stmt *> getChildren() {
185 return {getTrailingObjects<Stmt *>(), numChildren};
186 }
getChildren()187 ArrayRef<Stmt *> getChildren() const {
188 return const_cast<CompoundStmt *>(this)->getChildren();
189 }
begin()190 ArrayRef<Stmt *>::iterator begin() const { return getChildren().begin(); }
end()191 ArrayRef<Stmt *>::iterator end() const { return getChildren().end(); }
192
193 private:
CompoundStmt(SMRange location,unsigned numChildren)194 CompoundStmt(SMRange location, unsigned numChildren)
195 : Base(location), numChildren(numChildren) {}
196
197 /// The number of held children statements.
198 unsigned numChildren;
199
200 // Allow access to various privates.
201 friend class llvm::TrailingObjects<CompoundStmt, Stmt *>;
202 };
203
204 //===----------------------------------------------------------------------===//
205 // LetStmt
206 //===----------------------------------------------------------------------===//
207
208 /// This statement represents a `let` statement in PDLL. This statement is used
209 /// to define variables.
210 class LetStmt final : public Node::NodeBase<LetStmt, Stmt> {
211 public:
212 static LetStmt *create(Context &ctx, SMRange loc, VariableDecl *varDecl);
213
214 /// Return the variable defined by this statement.
getVarDecl()215 VariableDecl *getVarDecl() const { return varDecl; }
216
217 private:
LetStmt(SMRange loc,VariableDecl * varDecl)218 LetStmt(SMRange loc, VariableDecl *varDecl) : Base(loc), varDecl(varDecl) {}
219
220 /// The variable defined by this statement.
221 VariableDecl *varDecl;
222 };
223
224 //===----------------------------------------------------------------------===//
225 // OpRewriteStmt
226 //===----------------------------------------------------------------------===//
227
228 /// This class represents a base operation rewrite statement. Operation rewrite
229 /// statements perform a set of transformations on a given root operation.
230 class OpRewriteStmt : public Stmt {
231 public:
232 /// Provide type casting support.
233 static bool classof(const Node *node);
234
235 /// Return the root operation of this rewrite.
getRootOpExpr()236 Expr *getRootOpExpr() const { return rootOp; }
237
238 protected:
OpRewriteStmt(TypeID typeID,SMRange loc,Expr * rootOp)239 OpRewriteStmt(TypeID typeID, SMRange loc, Expr *rootOp)
240 : Stmt(typeID, loc), rootOp(rootOp) {}
241
242 protected:
243 /// The root operation being rewritten.
244 Expr *rootOp;
245 };
246
247 //===----------------------------------------------------------------------===//
248 // EraseStmt
249
250 /// This statement represents the `erase` statement in PDLL. This statement
251 /// erases the given root operation, corresponding roughly to the
252 /// PatternRewriter::eraseOp API.
253 class EraseStmt final : public Node::NodeBase<EraseStmt, OpRewriteStmt> {
254 public:
255 static EraseStmt *create(Context &ctx, SMRange loc, Expr *rootOp);
256
257 private:
EraseStmt(SMRange loc,Expr * rootOp)258 EraseStmt(SMRange loc, Expr *rootOp) : Base(loc, rootOp) {}
259 };
260
261 //===----------------------------------------------------------------------===//
262 // ReplaceStmt
263
264 /// This statement represents the `replace` statement in PDLL. This statement
265 /// replace the given root operation with a set of values, corresponding roughly
266 /// to the PatternRewriter::replaceOp API.
267 class ReplaceStmt final : public Node::NodeBase<ReplaceStmt, OpRewriteStmt>,
268 private llvm::TrailingObjects<ReplaceStmt, Expr *> {
269 public:
270 static ReplaceStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
271 ArrayRef<Expr *> replExprs);
272
273 /// Return the replacement values of this statement.
getReplExprs()274 MutableArrayRef<Expr *> getReplExprs() {
275 return {getTrailingObjects<Expr *>(), numReplExprs};
276 }
getReplExprs()277 ArrayRef<Expr *> getReplExprs() const {
278 return const_cast<ReplaceStmt *>(this)->getReplExprs();
279 }
280
281 private:
ReplaceStmt(SMRange loc,Expr * rootOp,unsigned numReplExprs)282 ReplaceStmt(SMRange loc, Expr *rootOp, unsigned numReplExprs)
283 : Base(loc, rootOp), numReplExprs(numReplExprs) {}
284
285 /// The number of replacement values within this statement.
286 unsigned numReplExprs;
287
288 /// TrailingObject utilities.
289 friend class llvm::TrailingObjects<ReplaceStmt, Expr *>;
290 };
291
292 //===----------------------------------------------------------------------===//
293 // RewriteStmt
294
295 /// This statement represents an operation rewrite that contains a block of
296 /// nested rewrite commands. This allows for building more complex operation
297 /// rewrites that span across multiple statements, which may be unconnected.
298 class RewriteStmt final : public Node::NodeBase<RewriteStmt, OpRewriteStmt> {
299 public:
300 static RewriteStmt *create(Context &ctx, SMRange loc, Expr *rootOp,
301 CompoundStmt *rewriteBody);
302
303 /// Return the compound rewrite body.
getRewriteBody()304 CompoundStmt *getRewriteBody() const { return rewriteBody; }
305
306 private:
RewriteStmt(SMRange loc,Expr * rootOp,CompoundStmt * rewriteBody)307 RewriteStmt(SMRange loc, Expr *rootOp, CompoundStmt *rewriteBody)
308 : Base(loc, rootOp), rewriteBody(rewriteBody) {}
309
310 /// The body of nested rewriters within this statement.
311 CompoundStmt *rewriteBody;
312 };
313
314 //===----------------------------------------------------------------------===//
315 // ReturnStmt
316 //===----------------------------------------------------------------------===//
317
318 /// This statement represents a return from a "callable" like decl, e.g. a
319 /// Constraint or a Rewrite.
320 class ReturnStmt final : public Node::NodeBase<ReturnStmt, Stmt> {
321 public:
322 static ReturnStmt *create(Context &ctx, SMRange loc, Expr *resultExpr);
323
324 /// Return the result expression of this statement.
getResultExpr()325 Expr *getResultExpr() { return resultExpr; }
getResultExpr()326 const Expr *getResultExpr() const { return resultExpr; }
327
328 /// Set the result expression of this statement.
setResultExpr(Expr * expr)329 void setResultExpr(Expr *expr) { resultExpr = expr; }
330
331 private:
ReturnStmt(SMRange loc,Expr * resultExpr)332 ReturnStmt(SMRange loc, Expr *resultExpr)
333 : Base(loc), resultExpr(resultExpr) {}
334
335 // The result expression of this statement.
336 Expr *resultExpr;
337 };
338
339 //===----------------------------------------------------------------------===//
340 // Expr
341 //===----------------------------------------------------------------------===//
342
343 /// This class represents a base AST Expression node.
344 class Expr : public Stmt {
345 public:
346 /// Return the type of this expression.
getType()347 Type getType() const { return type; }
348
349 /// Provide type casting support.
350 static bool classof(const Node *node);
351
352 protected:
Expr(TypeID typeID,SMRange loc,Type type)353 Expr(TypeID typeID, SMRange loc, Type type) : Stmt(typeID, loc), type(type) {}
354
355 private:
356 /// The type of this expression.
357 Type type;
358 };
359
360 //===----------------------------------------------------------------------===//
361 // AttributeExpr
362 //===----------------------------------------------------------------------===//
363
364 /// This expression represents a literal MLIR Attribute, and contains the
365 /// textual assembly format of that attribute.
366 class AttributeExpr : public Node::NodeBase<AttributeExpr, Expr> {
367 public:
368 static AttributeExpr *create(Context &ctx, SMRange loc, StringRef value);
369
370 /// Get the raw value of this expression. This is the textual assembly format
371 /// of the MLIR Attribute.
getValue()372 StringRef getValue() const { return value; }
373
374 private:
AttributeExpr(Context & ctx,SMRange loc,StringRef value)375 AttributeExpr(Context &ctx, SMRange loc, StringRef value)
376 : Base(loc, AttributeType::get(ctx)), value(value) {}
377
378 /// The value referenced by this expression.
379 StringRef value;
380 };
381
382 //===----------------------------------------------------------------------===//
383 // CallExpr
384 //===----------------------------------------------------------------------===//
385
386 /// This expression represents a call to a decl, such as a
387 /// UserConstraintDecl/UserRewriteDecl.
388 class CallExpr final : public Node::NodeBase<CallExpr, Expr>,
389 private llvm::TrailingObjects<CallExpr, Expr *> {
390 public:
391 static CallExpr *create(Context &ctx, SMRange loc, Expr *callable,
392 ArrayRef<Expr *> arguments, Type resultType);
393
394 /// Return the callable of this call.
getCallableExpr()395 Expr *getCallableExpr() const { return callable; }
396
397 /// Return the arguments of this call.
getArguments()398 MutableArrayRef<Expr *> getArguments() {
399 return {getTrailingObjects<Expr *>(), numArgs};
400 }
getArguments()401 ArrayRef<Expr *> getArguments() const {
402 return const_cast<CallExpr *>(this)->getArguments();
403 }
404
405 private:
CallExpr(SMRange loc,Type type,Expr * callable,unsigned numArgs)406 CallExpr(SMRange loc, Type type, Expr *callable, unsigned numArgs)
407 : Base(loc, type), callable(callable), numArgs(numArgs) {}
408
409 /// The callable of this call.
410 Expr *callable;
411
412 /// The number of arguments of the call.
413 unsigned numArgs;
414
415 /// TrailingObject utilities.
416 friend llvm::TrailingObjects<CallExpr, Expr *>;
417 };
418
419 //===----------------------------------------------------------------------===//
420 // DeclRefExpr
421 //===----------------------------------------------------------------------===//
422
423 /// This expression represents a reference to a Decl node.
424 class DeclRefExpr : public Node::NodeBase<DeclRefExpr, Expr> {
425 public:
426 static DeclRefExpr *create(Context &ctx, SMRange loc, Decl *decl, Type type);
427
428 /// Get the decl referenced by this expression.
getDecl()429 Decl *getDecl() const { return decl; }
430
431 private:
DeclRefExpr(SMRange loc,Decl * decl,Type type)432 DeclRefExpr(SMRange loc, Decl *decl, Type type)
433 : Base(loc, type), decl(decl) {}
434
435 /// The decl referenced by this expression.
436 Decl *decl;
437 };
438
439 //===----------------------------------------------------------------------===//
440 // MemberAccessExpr
441 //===----------------------------------------------------------------------===//
442
443 /// This expression represents a named member or field access of a given parent
444 /// expression.
445 class MemberAccessExpr : public Node::NodeBase<MemberAccessExpr, Expr> {
446 public:
447 static MemberAccessExpr *create(Context &ctx, SMRange loc,
448 const Expr *parentExpr, StringRef memberName,
449 Type type);
450
451 /// Get the parent expression of this access.
getParentExpr()452 const Expr *getParentExpr() const { return parentExpr; }
453
454 /// Return the name of the member being accessed.
getMemberName()455 StringRef getMemberName() const { return memberName; }
456
457 private:
MemberAccessExpr(SMRange loc,const Expr * parentExpr,StringRef memberName,Type type)458 MemberAccessExpr(SMRange loc, const Expr *parentExpr, StringRef memberName,
459 Type type)
460 : Base(loc, type), parentExpr(parentExpr), memberName(memberName) {}
461
462 /// The parent expression of this access.
463 const Expr *parentExpr;
464
465 /// The name of the member being accessed from the parent.
466 StringRef memberName;
467 };
468
469 //===----------------------------------------------------------------------===//
470 // AllResultsMemberAccessExpr
471
472 /// This class represents an instance of MemberAccessExpr that references all
473 /// results of an operation.
474 class AllResultsMemberAccessExpr : public MemberAccessExpr {
475 public:
476 /// Return the member name used for the "all-results" access.
getMemberName()477 static StringRef getMemberName() { return "$results"; }
478
create(Context & ctx,SMRange loc,const Expr * parentExpr,Type type)479 static AllResultsMemberAccessExpr *create(Context &ctx, SMRange loc,
480 const Expr *parentExpr, Type type) {
481 return cast<AllResultsMemberAccessExpr>(
482 MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
483 }
484
485 /// Provide type casting support.
classof(const Node * node)486 static bool classof(const Node *node) {
487 const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
488 return memAccess && memAccess->getMemberName() == getMemberName();
489 }
490 };
491
492 //===----------------------------------------------------------------------===//
493 // OperationExpr
494 //===----------------------------------------------------------------------===//
495
496 /// This expression represents the structural form of an MLIR Operation. It
497 /// represents either an input operation to match, or an operation to create
498 /// within a rewrite.
499 class OperationExpr final
500 : public Node::NodeBase<OperationExpr, Expr>,
501 private llvm::TrailingObjects<OperationExpr, Expr *,
502 NamedAttributeDecl *> {
503 public:
504 static OperationExpr *create(Context &ctx, SMRange loc,
505 const ods::Operation *odsOp,
506 const OpNameDecl *nameDecl,
507 ArrayRef<Expr *> operands,
508 ArrayRef<Expr *> resultTypes,
509 ArrayRef<NamedAttributeDecl *> attributes);
510
511 /// Return the name of the operation, or None if there isn't one.
512 Optional<StringRef> getName() const;
513
514 /// Return the declaration of the operation name.
getNameDecl()515 const OpNameDecl *getNameDecl() const { return nameDecl; }
516
517 /// Return the location of the name of the operation expression, or an invalid
518 /// location if there isn't a name.
getNameLoc()519 SMRange getNameLoc() const { return nameLoc; }
520
521 /// Return the operands of this operation.
getOperands()522 MutableArrayRef<Expr *> getOperands() {
523 return {getTrailingObjects<Expr *>(), numOperands};
524 }
getOperands()525 ArrayRef<Expr *> getOperands() const {
526 return const_cast<OperationExpr *>(this)->getOperands();
527 }
528
529 /// Return the result types of this operation.
getResultTypes()530 MutableArrayRef<Expr *> getResultTypes() {
531 return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
532 }
getResultTypes()533 MutableArrayRef<Expr *> getResultTypes() const {
534 return const_cast<OperationExpr *>(this)->getResultTypes();
535 }
536
537 /// Return the attributes of this operation.
getAttributes()538 MutableArrayRef<NamedAttributeDecl *> getAttributes() {
539 return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
540 }
getAttributes()541 MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
542 return const_cast<OperationExpr *>(this)->getAttributes();
543 }
544
545 private:
OperationExpr(SMRange loc,Type type,const OpNameDecl * nameDecl,unsigned numOperands,unsigned numResultTypes,unsigned numAttributes,SMRange nameLoc)546 OperationExpr(SMRange loc, Type type, const OpNameDecl *nameDecl,
547 unsigned numOperands, unsigned numResultTypes,
548 unsigned numAttributes, SMRange nameLoc)
549 : Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
550 numResultTypes(numResultTypes), numAttributes(numAttributes),
551 nameLoc(nameLoc) {}
552
553 /// The name decl of this expression.
554 const OpNameDecl *nameDecl;
555
556 /// The number of operands, result types, and attributes of the operation.
557 unsigned numOperands, numResultTypes, numAttributes;
558
559 /// The location of the operation name in the expression if it has a name.
560 SMRange nameLoc;
561
562 /// TrailingObject utilities.
563 friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
numTrailingObjects(OverloadToken<Expr * >)564 size_t numTrailingObjects(OverloadToken<Expr *>) const {
565 return numOperands + numResultTypes;
566 }
567 };
568
569 //===----------------------------------------------------------------------===//
570 // TupleExpr
571 //===----------------------------------------------------------------------===//
572
573 /// This expression builds a tuple from a set of element values.
574 class TupleExpr final : public Node::NodeBase<TupleExpr, Expr>,
575 private llvm::TrailingObjects<TupleExpr, Expr *> {
576 public:
577 static TupleExpr *create(Context &ctx, SMRange loc, ArrayRef<Expr *> elements,
578 ArrayRef<StringRef> elementNames);
579
580 /// Return the element expressions of this tuple.
getElements()581 MutableArrayRef<Expr *> getElements() {
582 return {getTrailingObjects<Expr *>(), getType().size()};
583 }
getElements()584 ArrayRef<Expr *> getElements() const {
585 return const_cast<TupleExpr *>(this)->getElements();
586 }
587
588 /// Return the tuple result type of this expression.
getType()589 TupleType getType() const { return Base::getType().cast<TupleType>(); }
590
591 private:
TupleExpr(SMRange loc,TupleType type)592 TupleExpr(SMRange loc, TupleType type) : Base(loc, type) {}
593
594 /// TrailingObject utilities.
595 friend class llvm::TrailingObjects<TupleExpr, Expr *>;
596 };
597
598 //===----------------------------------------------------------------------===//
599 // TypeExpr
600 //===----------------------------------------------------------------------===//
601
602 /// This expression represents a literal MLIR Type, and contains the textual
603 /// assembly format of that type.
604 class TypeExpr : public Node::NodeBase<TypeExpr, Expr> {
605 public:
606 static TypeExpr *create(Context &ctx, SMRange loc, StringRef value);
607
608 /// Get the raw value of this expression. This is the textual assembly format
609 /// of the MLIR Type.
getValue()610 StringRef getValue() const { return value; }
611
612 private:
TypeExpr(Context & ctx,SMRange loc,StringRef value)613 TypeExpr(Context &ctx, SMRange loc, StringRef value)
614 : Base(loc, TypeType::get(ctx)), value(value) {}
615
616 /// The value referenced by this expression.
617 StringRef value;
618 };
619
620 //===----------------------------------------------------------------------===//
621 // Decl
622 //===----------------------------------------------------------------------===//
623
624 /// This class represents the base Decl node.
625 class Decl : public Node {
626 public:
627 /// Return the name of the decl, or nullptr if it doesn't have one.
getName()628 const Name *getName() const { return name; }
629
630 /// Provide type casting support.
631 static bool classof(const Node *node);
632
633 /// Set the documentation comment for this decl.
634 void setDocComment(Context &ctx, StringRef comment);
635
636 /// Return the documentation comment attached to this decl if it has been set.
637 /// Otherwise, returns None.
getDocComment()638 Optional<StringRef> getDocComment() const { return docComment; }
639
640 protected:
641 Decl(TypeID typeID, SMRange loc, const Name *name = nullptr)
Node(typeID,loc)642 : Node(typeID, loc), name(name) {}
643
644 private:
645 /// The name of the decl. This is optional for some decls, such as
646 /// PatternDecl.
647 const Name *name;
648
649 /// The documentation comment attached to this decl. Defaults to None if
650 /// the comment is unset/unknown.
651 Optional<StringRef> docComment;
652 };
653
654 //===----------------------------------------------------------------------===//
655 // ConstraintDecl
656 //===----------------------------------------------------------------------===//
657
658 /// This class represents the base of all AST Constraint decls. Constraints
659 /// apply matcher conditions to, and define the type of PDLL variables.
660 class ConstraintDecl : public Decl {
661 public:
662 /// Provide type casting support.
663 static bool classof(const Node *node);
664
665 protected:
666 ConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
Decl(typeID,loc,name)667 : Decl(typeID, loc, name) {}
668 };
669
670 /// This class represents a reference to a constraint, and contains a constraint
671 /// and the location of the reference.
672 struct ConstraintRef {
ConstraintRefConstraintRef673 ConstraintRef(const ConstraintDecl *constraint, SMRange refLoc)
674 : constraint(constraint), referenceLoc(refLoc) {}
ConstraintRefConstraintRef675 explicit ConstraintRef(const ConstraintDecl *constraint)
676 : ConstraintRef(constraint, constraint->getLoc()) {}
677
678 const ConstraintDecl *constraint;
679 SMRange referenceLoc;
680 };
681
682 //===----------------------------------------------------------------------===//
683 // CoreConstraintDecl
684 //===----------------------------------------------------------------------===//
685
686 /// This class represents the base of all "core" constraints. Core constraints
687 /// are those that generally represent a concrete IR construct, such as
688 /// `Type`s or `Value`s.
689 class CoreConstraintDecl : public ConstraintDecl {
690 public:
691 /// Provide type casting support.
692 static bool classof(const Node *node);
693
694 protected:
695 CoreConstraintDecl(TypeID typeID, SMRange loc, const Name *name = nullptr)
ConstraintDecl(typeID,loc,name)696 : ConstraintDecl(typeID, loc, name) {}
697 };
698
699 //===----------------------------------------------------------------------===//
700 // AttrConstraintDecl
701
702 /// The class represents an Attribute constraint, and constrains a variable to
703 /// be an Attribute.
704 class AttrConstraintDecl
705 : public Node::NodeBase<AttrConstraintDecl, CoreConstraintDecl> {
706 public:
707 static AttrConstraintDecl *create(Context &ctx, SMRange loc,
708 Expr *typeExpr = nullptr);
709
710 /// Return the optional type the attribute is constrained to.
getTypeExpr()711 Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()712 const Expr *getTypeExpr() const { return typeExpr; }
713
714 protected:
AttrConstraintDecl(SMRange loc,Expr * typeExpr)715 AttrConstraintDecl(SMRange loc, Expr *typeExpr)
716 : Base(loc), typeExpr(typeExpr) {}
717
718 /// An optional type that the attribute is constrained to.
719 Expr *typeExpr;
720 };
721
722 //===----------------------------------------------------------------------===//
723 // OpConstraintDecl
724
725 /// The class represents an Operation constraint, and constrains a variable to
726 /// be an Operation.
727 class OpConstraintDecl
728 : public Node::NodeBase<OpConstraintDecl, CoreConstraintDecl> {
729 public:
730 static OpConstraintDecl *create(Context &ctx, SMRange loc,
731 const OpNameDecl *nameDecl = nullptr);
732
733 /// Return the name of the operation, or None if there isn't one.
734 Optional<StringRef> getName() const;
735
736 /// Return the declaration of the operation name.
getNameDecl()737 const OpNameDecl *getNameDecl() const { return nameDecl; }
738
739 protected:
OpConstraintDecl(SMRange loc,const OpNameDecl * nameDecl)740 explicit OpConstraintDecl(SMRange loc, const OpNameDecl *nameDecl)
741 : Base(loc), nameDecl(nameDecl) {}
742
743 /// The operation name of this constraint.
744 const OpNameDecl *nameDecl;
745 };
746
747 //===----------------------------------------------------------------------===//
748 // TypeConstraintDecl
749
750 /// The class represents a Type constraint, and constrains a variable to be a
751 /// Type.
752 class TypeConstraintDecl
753 : public Node::NodeBase<TypeConstraintDecl, CoreConstraintDecl> {
754 public:
755 static TypeConstraintDecl *create(Context &ctx, SMRange loc);
756
757 protected:
758 using Base::Base;
759 };
760
761 //===----------------------------------------------------------------------===//
762 // TypeRangeConstraintDecl
763
764 /// The class represents a TypeRange constraint, and constrains a variable to be
765 /// a TypeRange.
766 class TypeRangeConstraintDecl
767 : public Node::NodeBase<TypeRangeConstraintDecl, CoreConstraintDecl> {
768 public:
769 static TypeRangeConstraintDecl *create(Context &ctx, SMRange loc);
770
771 protected:
772 using Base::Base;
773 };
774
775 //===----------------------------------------------------------------------===//
776 // ValueConstraintDecl
777
778 /// The class represents a Value constraint, and constrains a variable to be a
779 /// Value.
780 class ValueConstraintDecl
781 : public Node::NodeBase<ValueConstraintDecl, CoreConstraintDecl> {
782 public:
783 static ValueConstraintDecl *create(Context &ctx, SMRange loc, Expr *typeExpr);
784
785 /// Return the optional type the value is constrained to.
getTypeExpr()786 Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()787 const Expr *getTypeExpr() const { return typeExpr; }
788
789 protected:
ValueConstraintDecl(SMRange loc,Expr * typeExpr)790 ValueConstraintDecl(SMRange loc, Expr *typeExpr)
791 : Base(loc), typeExpr(typeExpr) {}
792
793 /// An optional type that the value is constrained to.
794 Expr *typeExpr;
795 };
796
797 //===----------------------------------------------------------------------===//
798 // ValueRangeConstraintDecl
799
800 /// The class represents a ValueRange constraint, and constrains a variable to
801 /// be a ValueRange.
802 class ValueRangeConstraintDecl
803 : public Node::NodeBase<ValueRangeConstraintDecl, CoreConstraintDecl> {
804 public:
805 static ValueRangeConstraintDecl *create(Context &ctx, SMRange loc,
806 Expr *typeExpr = nullptr);
807
808 /// Return the optional type the value range is constrained to.
getTypeExpr()809 Expr *getTypeExpr() { return typeExpr; }
getTypeExpr()810 const Expr *getTypeExpr() const { return typeExpr; }
811
812 protected:
ValueRangeConstraintDecl(SMRange loc,Expr * typeExpr)813 ValueRangeConstraintDecl(SMRange loc, Expr *typeExpr)
814 : Base(loc), typeExpr(typeExpr) {}
815
816 /// An optional type that the value range is constrained to.
817 Expr *typeExpr;
818 };
819
820 //===----------------------------------------------------------------------===//
821 // UserConstraintDecl
822 //===----------------------------------------------------------------------===//
823
824 /// This decl represents a user defined constraint. This is either:
825 /// * an imported native constraint
826 /// - Similar to an external function declaration. This is a native
827 /// constraint defined externally, and imported into PDLL via a
828 /// declaration.
829 /// * a native constraint defined in PDLL
830 /// - This is a native constraint, i.e. a constraint whose implementation is
831 /// defined in C++(or potentially some other non-PDLL language). The
832 /// implementation of this constraint is specified as a string code block
833 /// in PDLL.
834 /// * a PDLL constraint
835 /// - This is a constraint which is defined using only PDLL constructs.
836 class UserConstraintDecl final
837 : public Node::NodeBase<UserConstraintDecl, ConstraintDecl>,
838 llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef> {
839 public:
840 /// Create a native constraint with the given optional code block.
841 static UserConstraintDecl *
842 createNative(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
843 ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
844 Type resultType, ArrayRef<StringRef> nativeInputTypes = {}) {
845 return createImpl(ctx, name, inputs, nativeInputTypes, results, codeBlock,
846 /*body=*/nullptr, resultType);
847 }
848
849 /// Create a PDLL constraint with the given body.
createPDLL(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,const CompoundStmt * body,Type resultType)850 static UserConstraintDecl *createPDLL(Context &ctx, const Name &name,
851 ArrayRef<VariableDecl *> inputs,
852 ArrayRef<VariableDecl *> results,
853 const CompoundStmt *body,
854 Type resultType) {
855 return createImpl(ctx, name, inputs, /*nativeInputTypes=*/llvm::None,
856 results, /*codeBlock=*/llvm::None, body, resultType);
857 }
858
859 /// Return the name of the constraint.
getName()860 const Name &getName() const { return *Decl::getName(); }
861
862 /// Return the input arguments of this constraint.
getInputs()863 MutableArrayRef<VariableDecl *> getInputs() {
864 return {getTrailingObjects<VariableDecl *>(), numInputs};
865 }
getInputs()866 ArrayRef<VariableDecl *> getInputs() const {
867 return const_cast<UserConstraintDecl *>(this)->getInputs();
868 }
869
870 /// Return the explicit native type to use for the given input. Returns None
871 /// if no explicit type was set.
872 Optional<StringRef> getNativeInputType(unsigned index) const;
873
874 /// Return the explicit results of the constraint declaration. May be empty,
875 /// even if the constraint has results (e.g. in the case of inferred results).
getResults()876 MutableArrayRef<VariableDecl *> getResults() {
877 return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
878 }
getResults()879 ArrayRef<VariableDecl *> getResults() const {
880 return const_cast<UserConstraintDecl *>(this)->getResults();
881 }
882
883 /// Return the optional code block of this constraint, if this is a native
884 /// constraint with a provided implementation.
getCodeBlock()885 Optional<StringRef> getCodeBlock() const { return codeBlock; }
886
887 /// Return the body of this constraint if this constraint is a PDLL
888 /// constraint, otherwise returns nullptr.
getBody()889 const CompoundStmt *getBody() const { return constraintBody; }
890
891 /// Return the result type of this constraint.
getResultType()892 Type getResultType() const { return resultType; }
893
894 /// Returns true if this constraint is external.
isExternal()895 bool isExternal() const { return !constraintBody && !codeBlock; }
896
897 private:
898 /// Create either a PDLL constraint or a native constraint with the given
899 /// components.
900 static UserConstraintDecl *
901 createImpl(Context &ctx, const Name &name, ArrayRef<VariableDecl *> inputs,
902 ArrayRef<StringRef> nativeInputTypes,
903 ArrayRef<VariableDecl *> results, Optional<StringRef> codeBlock,
904 const CompoundStmt *body, Type resultType);
905
UserConstraintDecl(const Name & name,unsigned numInputs,bool hasNativeInputTypes,unsigned numResults,Optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)906 UserConstraintDecl(const Name &name, unsigned numInputs,
907 bool hasNativeInputTypes, unsigned numResults,
908 Optional<StringRef> codeBlock, const CompoundStmt *body,
909 Type resultType)
910 : Base(name.getLoc(), &name), numInputs(numInputs),
911 numResults(numResults), codeBlock(codeBlock), constraintBody(body),
912 resultType(resultType) {}
913
914 /// The number of inputs to this constraint.
915 unsigned numInputs;
916
917 /// The number of explicit results to this constraint.
918 unsigned numResults;
919
920 /// The optional code block of this constraint.
921 Optional<StringRef> codeBlock;
922
923 /// The optional body of this constraint.
924 const CompoundStmt *constraintBody;
925
926 /// The result type of the constraint.
927 Type resultType;
928
929 /// Flag indicating if this constraint has explicit native input types.
930 bool hasNativeInputTypes;
931
932 /// Allow access to various internals.
933 friend llvm::TrailingObjects<UserConstraintDecl, VariableDecl *, StringRef>;
numTrailingObjects(OverloadToken<VariableDecl * >)934 size_t numTrailingObjects(OverloadToken<VariableDecl *>) const {
935 return numInputs + numResults;
936 }
937 };
938
939 //===----------------------------------------------------------------------===//
940 // NamedAttributeDecl
941 //===----------------------------------------------------------------------===//
942
943 /// This Decl represents a NamedAttribute, and contains a string name and
944 /// attribute value.
945 class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
946 public:
947 static NamedAttributeDecl *create(Context &ctx, const Name &name,
948 Expr *value);
949
950 /// Return the name of the attribute.
getName()951 const Name &getName() const { return *Decl::getName(); }
952
953 /// Return value of the attribute.
getValue()954 Expr *getValue() const { return value; }
955
956 private:
NamedAttributeDecl(const Name & name,Expr * value)957 NamedAttributeDecl(const Name &name, Expr *value)
958 : Base(name.getLoc(), &name), value(value) {}
959
960 /// The value of the attribute.
961 Expr *value;
962 };
963
964 //===----------------------------------------------------------------------===//
965 // OpNameDecl
966 //===----------------------------------------------------------------------===//
967
968 /// This Decl represents an OperationName.
969 class OpNameDecl : public Node::NodeBase<OpNameDecl, Decl> {
970 public:
971 static OpNameDecl *create(Context &ctx, const Name &name);
972 static OpNameDecl *create(Context &ctx, SMRange loc);
973
974 /// Return the name of this operation, or none if the name is unknown.
getName()975 Optional<StringRef> getName() const {
976 const Name *name = Decl::getName();
977 return name ? Optional<StringRef>(name->getName()) : llvm::None;
978 }
979
980 private:
OpNameDecl(const Name & name)981 explicit OpNameDecl(const Name &name) : Base(name.getLoc(), &name) {}
OpNameDecl(SMRange loc)982 explicit OpNameDecl(SMRange loc) : Base(loc) {}
983 };
984
985 //===----------------------------------------------------------------------===//
986 // PatternDecl
987 //===----------------------------------------------------------------------===//
988
989 /// This Decl represents a single Pattern.
990 class PatternDecl : public Node::NodeBase<PatternDecl, Decl> {
991 public:
992 static PatternDecl *create(Context &ctx, SMRange location, const Name *name,
993 Optional<uint16_t> benefit,
994 bool hasBoundedRecursion,
995 const CompoundStmt *body);
996
997 /// Return the benefit of this pattern if specified, or None.
getBenefit()998 Optional<uint16_t> getBenefit() const { return benefit; }
999
1000 /// Return if this pattern has bounded rewrite recursion.
hasBoundedRewriteRecursion()1001 bool hasBoundedRewriteRecursion() const { return hasBoundedRecursion; }
1002
1003 /// Return the body of this pattern.
getBody()1004 const CompoundStmt *getBody() const { return patternBody; }
1005
1006 /// Return the root rewrite statement of this pattern.
getRootRewriteStmt()1007 const OpRewriteStmt *getRootRewriteStmt() const {
1008 return cast<OpRewriteStmt>(patternBody->getChildren().back());
1009 }
1010
1011 private:
PatternDecl(SMRange loc,const Name * name,Optional<uint16_t> benefit,bool hasBoundedRecursion,const CompoundStmt * body)1012 PatternDecl(SMRange loc, const Name *name, Optional<uint16_t> benefit,
1013 bool hasBoundedRecursion, const CompoundStmt *body)
1014 : Base(loc, name), benefit(benefit),
1015 hasBoundedRecursion(hasBoundedRecursion), patternBody(body) {}
1016
1017 /// The benefit of the pattern if it was explicitly specified, None otherwise.
1018 Optional<uint16_t> benefit;
1019
1020 /// If the pattern has properly bounded rewrite recursion or not.
1021 bool hasBoundedRecursion;
1022
1023 /// The compound statement representing the body of the pattern.
1024 const CompoundStmt *patternBody;
1025 };
1026
1027 //===----------------------------------------------------------------------===//
1028 // UserRewriteDecl
1029 //===----------------------------------------------------------------------===//
1030
1031 /// This decl represents a user defined rewrite. This is either:
1032 /// * an imported native rewrite
1033 /// - Similar to an external function declaration. This is a native
1034 /// rewrite defined externally, and imported into PDLL via a declaration.
1035 /// * a native rewrite defined in PDLL
1036 /// - This is a native rewrite, i.e. a rewrite whose implementation is
1037 /// defined in C++(or potentially some other non-PDLL language). The
1038 /// implementation of this rewrite is specified as a string code block
1039 /// in PDLL.
1040 /// * a PDLL rewrite
1041 /// - This is a rewrite which is defined using only PDLL constructs.
1042 class UserRewriteDecl final
1043 : public Node::NodeBase<UserRewriteDecl, Decl>,
1044 llvm::TrailingObjects<UserRewriteDecl, VariableDecl *> {
1045 public:
1046 /// Create a native rewrite with the given optional code block.
createNative(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,Optional<StringRef> codeBlock,Type resultType)1047 static UserRewriteDecl *createNative(Context &ctx, const Name &name,
1048 ArrayRef<VariableDecl *> inputs,
1049 ArrayRef<VariableDecl *> results,
1050 Optional<StringRef> codeBlock,
1051 Type resultType) {
1052 return createImpl(ctx, name, inputs, results, codeBlock, /*body=*/nullptr,
1053 resultType);
1054 }
1055
1056 /// Create a PDLL rewrite with the given body.
createPDLL(Context & ctx,const Name & name,ArrayRef<VariableDecl * > inputs,ArrayRef<VariableDecl * > results,const CompoundStmt * body,Type resultType)1057 static UserRewriteDecl *createPDLL(Context &ctx, const Name &name,
1058 ArrayRef<VariableDecl *> inputs,
1059 ArrayRef<VariableDecl *> results,
1060 const CompoundStmt *body,
1061 Type resultType) {
1062 return createImpl(ctx, name, inputs, results, /*codeBlock=*/llvm::None,
1063 body, resultType);
1064 }
1065
1066 /// Return the name of the rewrite.
getName()1067 const Name &getName() const { return *Decl::getName(); }
1068
1069 /// Return the input arguments of this rewrite.
getInputs()1070 MutableArrayRef<VariableDecl *> getInputs() {
1071 return {getTrailingObjects<VariableDecl *>(), numInputs};
1072 }
getInputs()1073 ArrayRef<VariableDecl *> getInputs() const {
1074 return const_cast<UserRewriteDecl *>(this)->getInputs();
1075 }
1076
1077 /// Return the explicit results of the rewrite declaration. May be empty,
1078 /// even if the rewrite has results (e.g. in the case of inferred results).
getResults()1079 MutableArrayRef<VariableDecl *> getResults() {
1080 return {getTrailingObjects<VariableDecl *>() + numInputs, numResults};
1081 }
getResults()1082 ArrayRef<VariableDecl *> getResults() const {
1083 return const_cast<UserRewriteDecl *>(this)->getResults();
1084 }
1085
1086 /// Return the optional code block of this rewrite, if this is a native
1087 /// rewrite with a provided implementation.
getCodeBlock()1088 Optional<StringRef> getCodeBlock() const { return codeBlock; }
1089
1090 /// Return the body of this rewrite if this rewrite is a PDLL rewrite,
1091 /// otherwise returns nullptr.
getBody()1092 const CompoundStmt *getBody() const { return rewriteBody; }
1093
1094 /// Return the result type of this rewrite.
getResultType()1095 Type getResultType() const { return resultType; }
1096
1097 /// Returns true if this rewrite is external.
isExternal()1098 bool isExternal() const { return !rewriteBody && !codeBlock; }
1099
1100 private:
1101 /// Create either a PDLL rewrite or a native rewrite with the given
1102 /// components.
1103 static UserRewriteDecl *createImpl(Context &ctx, const Name &name,
1104 ArrayRef<VariableDecl *> inputs,
1105 ArrayRef<VariableDecl *> results,
1106 Optional<StringRef> codeBlock,
1107 const CompoundStmt *body, Type resultType);
1108
UserRewriteDecl(const Name & name,unsigned numInputs,unsigned numResults,Optional<StringRef> codeBlock,const CompoundStmt * body,Type resultType)1109 UserRewriteDecl(const Name &name, unsigned numInputs, unsigned numResults,
1110 Optional<StringRef> codeBlock, const CompoundStmt *body,
1111 Type resultType)
1112 : Base(name.getLoc(), &name), numInputs(numInputs),
1113 numResults(numResults), codeBlock(codeBlock), rewriteBody(body),
1114 resultType(resultType) {}
1115
1116 /// The number of inputs to this rewrite.
1117 unsigned numInputs;
1118
1119 /// The number of explicit results to this rewrite.
1120 unsigned numResults;
1121
1122 /// The optional code block of this rewrite.
1123 Optional<StringRef> codeBlock;
1124
1125 /// The optional body of this rewrite.
1126 const CompoundStmt *rewriteBody;
1127
1128 /// The result type of the rewrite.
1129 Type resultType;
1130
1131 /// Allow access to various internals.
1132 friend llvm::TrailingObjects<UserRewriteDecl, VariableDecl *>;
1133 };
1134
1135 //===----------------------------------------------------------------------===//
1136 // CallableDecl
1137 //===----------------------------------------------------------------------===//
1138
1139 /// This decl represents a shared interface for all callable decls.
1140 class CallableDecl : public Decl {
1141 public:
1142 /// Return the callable type of this decl.
getCallableType()1143 StringRef getCallableType() const {
1144 if (isa<UserConstraintDecl>(this))
1145 return "constraint";
1146 assert(isa<UserRewriteDecl>(this) && "unknown callable type");
1147 return "rewrite";
1148 }
1149
1150 /// Return the inputs of this decl.
getInputs()1151 ArrayRef<VariableDecl *> getInputs() const {
1152 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1153 return cst->getInputs();
1154 return cast<UserRewriteDecl>(this)->getInputs();
1155 }
1156
1157 /// Return the result type of this decl.
getResultType()1158 Type getResultType() const {
1159 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1160 return cst->getResultType();
1161 return cast<UserRewriteDecl>(this)->getResultType();
1162 }
1163
1164 /// Return the explicit results of the declaration. Note that these may be
1165 /// empty, even if the callable has results (e.g. in the case of inferred
1166 /// results).
getResults()1167 ArrayRef<VariableDecl *> getResults() const {
1168 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1169 return cst->getResults();
1170 return cast<UserRewriteDecl>(this)->getResults();
1171 }
1172
1173 /// Return the optional code block of this callable, if this is a native
1174 /// callable with a provided implementation.
getCodeBlock()1175 Optional<StringRef> getCodeBlock() const {
1176 if (const auto *cst = dyn_cast<UserConstraintDecl>(this))
1177 return cst->getCodeBlock();
1178 return cast<UserRewriteDecl>(this)->getCodeBlock();
1179 }
1180
1181 /// Support LLVM type casting facilities.
classof(const Node * decl)1182 static bool classof(const Node *decl) {
1183 return isa<UserConstraintDecl, UserRewriteDecl>(decl);
1184 }
1185 };
1186
1187 //===----------------------------------------------------------------------===//
1188 // VariableDecl
1189 //===----------------------------------------------------------------------===//
1190
1191 /// This Decl represents the definition of a PDLL variable.
1192 class VariableDecl final
1193 : public Node::NodeBase<VariableDecl, Decl>,
1194 private llvm::TrailingObjects<VariableDecl, ConstraintRef> {
1195 public:
1196 static VariableDecl *create(Context &ctx, const Name &name, Type type,
1197 Expr *initExpr,
1198 ArrayRef<ConstraintRef> constraints);
1199
1200 /// Return the constraints of this variable.
getConstraints()1201 MutableArrayRef<ConstraintRef> getConstraints() {
1202 return {getTrailingObjects<ConstraintRef>(), numConstraints};
1203 }
getConstraints()1204 ArrayRef<ConstraintRef> getConstraints() const {
1205 return const_cast<VariableDecl *>(this)->getConstraints();
1206 }
1207
1208 /// Return the initializer expression of this statement, or nullptr if there
1209 /// was no initializer.
getInitExpr()1210 Expr *getInitExpr() const { return initExpr; }
1211
1212 /// Return the name of the decl.
getName()1213 const Name &getName() const { return *Decl::getName(); }
1214
1215 /// Return the type of the decl.
getType()1216 Type getType() const { return type; }
1217
1218 private:
VariableDecl(const Name & name,Type type,Expr * initExpr,unsigned numConstraints)1219 VariableDecl(const Name &name, Type type, Expr *initExpr,
1220 unsigned numConstraints)
1221 : Base(name.getLoc(), &name), type(type), initExpr(initExpr),
1222 numConstraints(numConstraints) {}
1223
1224 /// The type of the variable.
1225 Type type;
1226
1227 /// The optional initializer expression of this statement.
1228 Expr *initExpr;
1229
1230 /// The number of constraints attached to this variable.
1231 unsigned numConstraints;
1232
1233 /// Allow access to various internals.
1234 friend llvm::TrailingObjects<VariableDecl, ConstraintRef>;
1235 };
1236
1237 //===----------------------------------------------------------------------===//
1238 // Module
1239 //===----------------------------------------------------------------------===//
1240
1241 /// This class represents a top-level AST module.
1242 class Module final : public Node::NodeBase<Module, Node>,
1243 private llvm::TrailingObjects<Module, Decl *> {
1244 public:
1245 static Module *create(Context &ctx, SMLoc loc, ArrayRef<Decl *> children);
1246
1247 /// Return the children of this module.
getChildren()1248 MutableArrayRef<Decl *> getChildren() {
1249 return {getTrailingObjects<Decl *>(), numChildren};
1250 }
getChildren()1251 ArrayRef<Decl *> getChildren() const {
1252 return const_cast<Module *>(this)->getChildren();
1253 }
1254
1255 private:
Module(SMLoc loc,unsigned numChildren)1256 Module(SMLoc loc, unsigned numChildren)
1257 : Base(SMRange{loc, loc}), numChildren(numChildren) {}
1258
1259 /// The number of decls held by this module.
1260 unsigned numChildren;
1261
1262 /// Allow access to various internals.
1263 friend llvm::TrailingObjects<Module, Decl *>;
1264 };
1265
1266 //===----------------------------------------------------------------------===//
// Deferred Method Definitions
1268 //===----------------------------------------------------------------------===//
1269
classof(const Node * node)1270 inline bool Decl::classof(const Node *node) {
1271 return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
1272 UserRewriteDecl, VariableDecl>(node);
1273 }
1274
classof(const Node * node)1275 inline bool ConstraintDecl::classof(const Node *node) {
1276 return isa<CoreConstraintDecl, UserConstraintDecl>(node);
1277 }
1278
classof(const Node * node)1279 inline bool CoreConstraintDecl::classof(const Node *node) {
1280 return isa<AttrConstraintDecl, OpConstraintDecl, TypeConstraintDecl,
1281 TypeRangeConstraintDecl, ValueConstraintDecl,
1282 ValueRangeConstraintDecl>(node);
1283 }
1284
classof(const Node * node)1285 inline bool Expr::classof(const Node *node) {
1286 return isa<AttributeExpr, CallExpr, DeclRefExpr, MemberAccessExpr,
1287 OperationExpr, TupleExpr, TypeExpr>(node);
1288 }
1289
classof(const Node * node)1290 inline bool OpRewriteStmt::classof(const Node *node) {
1291 return isa<EraseStmt, ReplaceStmt, RewriteStmt>(node);
1292 }
1293
classof(const Node * node)1294 inline bool Stmt::classof(const Node *node) {
1295 return isa<CompoundStmt, LetStmt, OpRewriteStmt, Expr>(node);
1296 }
1297
1298 } // namespace ast
1299 } // namespace pdll
1300 } // namespace mlir
1301
1302 #endif // MLIR_TOOLS_PDLL_AST_NODES_H_
1303