1 //===--- ASTTypeTraits.cpp --------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //  Provides a dynamic type identifier and a dynamically typed node container
10 //  that can be used to store an AST base node at runtime in the same storage in
11 //  a type safe way.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "clang/AST/ASTTypeTraits.h"
16 #include "clang/AST/ASTContext.h"
17 #include "clang/AST/Attr.h"
18 #include "clang/AST/DeclCXX.h"
19 #include "clang/AST/NestedNameSpecifier.h"
20 #include "clang/AST/OpenMPClause.h"
21 
22 using namespace clang;
23 
// Static metadata for every node kind: each entry is {parent kind id, name}.
// The table is indexed directly by a kind id (isBaseOf(), getCladeKind() and
// asStringRef() all do AllKindInfo[<id>]), so the order of the entries
// presumably has to mirror the order of the NodeKindId enumerators -- confirm
// against the enum declaration before adding, removing or moving an entry.
const ASTNodeKind::KindInfo ASTNodeKind::AllKindInfo[] = {
    // Explicitly listed kinds; each records NKI_None as its parent id.
    {NKI_None, "<None>"},
    {NKI_None, "TemplateArgument"},
    {NKI_None, "TemplateArgumentLoc"},
    {NKI_None, "TemplateName"},
    {NKI_None, "NestedNameSpecifierLoc"},
    {NKI_None, "QualType"},
    {NKI_None, "TypeLoc"},
    {NKI_None, "CXXBaseSpecifier"},
    {NKI_None, "CXXCtorInitializer"},
    {NKI_None, "NestedNameSpecifier"},
    {NKI_None, "Decl"},
// One entry per DECL(DERIVED, BASE) use in DeclNodes.inc: parent id is
// NKI_<BASE>, name is "<DERIVED>Decl".
#define DECL(DERIVED, BASE) { NKI_##BASE, #DERIVED "Decl" },
#include "clang/AST/DeclNodes.inc"
    {NKI_None, "Stmt"},
// One entry per STMT(DERIVED, BASE) use in StmtNodes.inc: parent id is
// NKI_<BASE>, name is "<DERIVED>" (no suffix is appended here).
#define STMT(DERIVED, BASE) { NKI_##BASE, #DERIVED },
#include "clang/AST/StmtNodes.inc"
    {NKI_None, "Type"},
// One entry per TYPE(DERIVED, BASE) use in TypeNodes.inc: parent id is
// NKI_<BASE>, name is "<DERIVED>Type".
#define TYPE(DERIVED, BASE) { NKI_##BASE, #DERIVED "Type" },
#include "clang/AST/TypeNodes.inc"
    {NKI_None, "OMPClause"},
// One entry per CLAUSE_CLASS(Enum, Str, Class) use in OMP.inc: every clause
// class records NKI_OMPClause as its parent id and is named "<Class>".
// GEN_CLANG_CLAUSE_CLASS is defined before the include; presumably it selects
// the clause-class section of OMP.inc -- confirm against that file.
#define GEN_CLANG_CLAUSE_CLASS
#define CLAUSE_CLASS(Enum, Str, Class) {NKI_OMPClause, #Class},
#include "llvm/Frontend/OpenMP/OMP.inc"
    {NKI_None, "Attr"},
// One entry per ATTR(A) use in AttrList.inc: every attribute records NKI_Attr
// as its parent id and is named "<A>Attr".
#define ATTR(A) {NKI_Attr, #A "Attr"},
#include "clang/Basic/AttrList.inc"
};
52 
53 bool ASTNodeKind::isBaseOf(ASTNodeKind Other, unsigned *Distance) const {
54   return isBaseOf(KindId, Other.KindId, Distance);
55 }
56 
57 bool ASTNodeKind::isBaseOf(NodeKindId Base, NodeKindId Derived,
58                            unsigned *Distance) {
59   if (Base == NKI_None || Derived == NKI_None) return false;
60   unsigned Dist = 0;
61   while (Derived != Base && Derived != NKI_None) {
62     Derived = AllKindInfo[Derived].ParentId;
63     ++Dist;
64   }
65   if (Distance)
66     *Distance = Dist;
67   return Derived == Base;
68 }
69 
70 ASTNodeKind ASTNodeKind::getCladeKind() const {
71   NodeKindId LastId = KindId;
72   while (LastId) {
73     NodeKindId ParentId = AllKindInfo[LastId].ParentId;
74     if (ParentId == NKI_None)
75       return LastId;
76     LastId = ParentId;
77   }
78   return NKI_None;
79 }
80 
81 StringRef ASTNodeKind::asStringRef() const { return AllKindInfo[KindId].Name; }
82 
83 ASTNodeKind ASTNodeKind::getMostDerivedType(ASTNodeKind Kind1,
84                                             ASTNodeKind Kind2) {
85   if (Kind1.isBaseOf(Kind2)) return Kind2;
86   if (Kind2.isBaseOf(Kind1)) return Kind1;
87   return ASTNodeKind();
88 }
89 
90 ASTNodeKind ASTNodeKind::getMostDerivedCommonAncestor(ASTNodeKind Kind1,
91                                                       ASTNodeKind Kind2) {
92   NodeKindId Parent = Kind1.KindId;
93   while (!isBaseOf(Parent, Kind2.KindId, nullptr) && Parent != NKI_None) {
94     Parent = AllKindInfo[Parent].ParentId;
95   }
96   return ASTNodeKind(Parent);
97 }
98 
99 ASTNodeKind ASTNodeKind::getFromNode(const Decl &D) {
100   switch (D.getKind()) {
101 #define DECL(DERIVED, BASE)                                                    \
102     case Decl::DERIVED: return ASTNodeKind(NKI_##DERIVED##Decl);
103 #define ABSTRACT_DECL(D)
104 #include "clang/AST/DeclNodes.inc"
105   };
106   llvm_unreachable("invalid decl kind");
107 }
108 
109 ASTNodeKind ASTNodeKind::getFromNode(const Stmt &S) {
110   switch (S.getStmtClass()) {
111     case Stmt::NoStmtClass: return NKI_None;
112 #define STMT(CLASS, PARENT)                                                    \
113     case Stmt::CLASS##Class: return ASTNodeKind(NKI_##CLASS);
114 #define ABSTRACT_STMT(S)
115 #include "clang/AST/StmtNodes.inc"
116   }
117   llvm_unreachable("invalid stmt kind");
118 }
119 
120 ASTNodeKind ASTNodeKind::getFromNode(const Type &T) {
121   switch (T.getTypeClass()) {
122 #define TYPE(Class, Base)                                                      \
123     case Type::Class: return ASTNodeKind(NKI_##Class##Type);
124 #define ABSTRACT_TYPE(Class, Base)
125 #include "clang/AST/TypeNodes.inc"
126   }
127   llvm_unreachable("invalid type kind");
128  }
129 
130 ASTNodeKind ASTNodeKind::getFromNode(const OMPClause &C) {
131   switch (C.getClauseKind()) {
132 #define GEN_CLANG_CLAUSE_CLASS
133 #define CLAUSE_CLASS(Enum, Str, Class)                                         \
134   case llvm::omp::Clause::Enum:                                                \
135     return ASTNodeKind(NKI_##Class);
136 #define CLAUSE_NO_CLASS(Enum, Str)                                             \
137   case llvm::omp::Clause::Enum:                                                \
138     llvm_unreachable("unexpected OpenMP clause kind");
139 #include "llvm/Frontend/OpenMP/OMP.inc"
140   }
141   llvm_unreachable("invalid omp clause kind");
142 }
143 
144 ASTNodeKind ASTNodeKind::getFromNode(const Attr &A) {
145   switch (A.getKind()) {
146 #define ATTR(A)                                                                \
147   case attr::A:                                                                \
148     return ASTNodeKind(NKI_##A##Attr);
149 #include "clang/Basic/AttrList.inc"
150   }
151   llvm_unreachable("invalid attr kind");
152 }
153 
154 void DynTypedNode::print(llvm::raw_ostream &OS,
155                          const PrintingPolicy &PP) const {
156   if (const TemplateArgument *TA = get<TemplateArgument>())
157     TA->print(PP, OS, /*IncludeType*/ true);
158   else if (const TemplateArgumentLoc *TAL = get<TemplateArgumentLoc>())
159     TAL->getArgument().print(PP, OS, /*IncludeType*/ true);
160   else if (const TemplateName *TN = get<TemplateName>())
161     TN->print(OS, PP);
162   else if (const NestedNameSpecifier *NNS = get<NestedNameSpecifier>())
163     NNS->print(OS, PP);
164   else if (const NestedNameSpecifierLoc *NNSL = get<NestedNameSpecifierLoc>()) {
165     if (const NestedNameSpecifier *NNS = NNSL->getNestedNameSpecifier())
166       NNS->print(OS, PP);
167     else
168       OS << "(empty NestedNameSpecifierLoc)";
169   } else if (const QualType *QT = get<QualType>())
170     QT->print(OS, PP);
171   else if (const TypeLoc *TL = get<TypeLoc>())
172     TL->getType().print(OS, PP);
173   else if (const Decl *D = get<Decl>())
174     D->print(OS, PP);
175   else if (const Stmt *S = get<Stmt>())
176     S->printPretty(OS, nullptr, PP);
177   else if (const Type *T = get<Type>())
178     QualType(T, 0).print(OS, PP);
179   else if (const Attr *A = get<Attr>())
180     A->printPretty(OS, PP);
181   else
182     OS << "Unable to print values of type " << NodeKind.asStringRef() << "\n";
183 }
184 
185 void DynTypedNode::dump(llvm::raw_ostream &OS,
186                         const ASTContext &Context) const {
187   if (const Decl *D = get<Decl>())
188     D->dump(OS);
189   else if (const Stmt *S = get<Stmt>())
190     S->dump(OS, Context);
191   else if (const Type *T = get<Type>())
192     T->dump(OS, Context);
193   else
194     OS << "Unable to dump values of type " << NodeKind.asStringRef() << "\n";
195 }
196 
197 SourceRange DynTypedNode::getSourceRange() const {
198   if (const CXXCtorInitializer *CCI = get<CXXCtorInitializer>())
199     return CCI->getSourceRange();
200   if (const NestedNameSpecifierLoc *NNSL = get<NestedNameSpecifierLoc>())
201     return NNSL->getSourceRange();
202   if (const TypeLoc *TL = get<TypeLoc>())
203     return TL->getSourceRange();
204   if (const Decl *D = get<Decl>())
205     return D->getSourceRange();
206   if (const Stmt *S = get<Stmt>())
207     return S->getSourceRange();
208   if (const TemplateArgumentLoc *TAL = get<TemplateArgumentLoc>())
209     return TAL->getSourceRange();
210   if (const auto *C = get<OMPClause>())
211     return SourceRange(C->getBeginLoc(), C->getEndLoc());
212   if (const auto *CBS = get<CXXBaseSpecifier>())
213     return CBS->getSourceRange();
214   if (const auto *A = get<Attr>())
215     return A->getRange();
216   return SourceRange();
217 }
218