1 //===- BuildTree.cpp ------------------------------------------*- C++ -*-=====//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 #include "clang/Tooling/Syntax/BuildTree.h"
9 #include "clang/AST/ASTFwd.h"
10 #include "clang/AST/Decl.h"
11 #include "clang/AST/DeclBase.h"
12 #include "clang/AST/DeclCXX.h"
13 #include "clang/AST/DeclarationName.h"
14 #include "clang/AST/RecursiveASTVisitor.h"
15 #include "clang/AST/Stmt.h"
16 #include "clang/AST/TypeLoc.h"
17 #include "clang/AST/TypeLocVisitor.h"
18 #include "clang/Basic/LLVM.h"
19 #include "clang/Basic/SourceLocation.h"
20 #include "clang/Basic/SourceManager.h"
21 #include "clang/Basic/Specifiers.h"
22 #include "clang/Basic/TokenKinds.h"
23 #include "clang/Lex/Lexer.h"
24 #include "clang/Tooling/Syntax/Nodes.h"
25 #include "clang/Tooling/Syntax/Tokens.h"
26 #include "clang/Tooling/Syntax/Tree.h"
27 #include "llvm/ADT/ArrayRef.h"
28 #include "llvm/ADT/STLExtras.h"
29 #include "llvm/ADT/ScopeExit.h"
30 #include "llvm/ADT/SmallVector.h"
31 #include "llvm/Support/Allocator.h"
32 #include "llvm/Support/Casting.h"
33 #include "llvm/Support/Compiler.h"
34 #include "llvm/Support/FormatVariadic.h"
35 #include "llvm/Support/MemoryBuffer.h"
36 #include "llvm/Support/raw_ostream.h"
37 #include <map>
38 
39 using namespace clang;
40 
41 LLVM_ATTRIBUTE_UNUSED
42 static bool isImplicitExpr(clang::Expr *E) { return E->IgnoreImplicit() != E; }
43 
44 static SourceLocation getQualifiedNameStart(DeclaratorDecl *D) {
45   auto DN = D->getDeclName();
46   bool IsAnonymous = DN.isIdentifier() && !DN.getAsIdentifierInfo();
47   if (IsAnonymous)
48     return SourceLocation();
49   return D->getQualifierLoc() ? D->getQualifierLoc().getBeginLoc()
50                               : D->getLocation();
51 }
52 
53 namespace {
54 /// Get start location of the Declarator from the TypeLoc.
55 /// E.g.:
56 ///   loc of `(` in `int (a)`
57 ///   loc of `*` in `int *(a)`
58 ///   loc of the first `(` in `int (*a)(int)`
59 ///   loc of the `*` in `int *(a)(int)`
60 ///   loc of the first `*` in `const int *const *volatile a;`
61 ///
62 /// It is non-trivial to get the start location because TypeLocs are stored
63 /// inside out. In the example above `*volatile` is the TypeLoc returned
64 /// by `Decl.getTypeSourceInfo()`, and `*const` is what `.getPointeeLoc()`
65 /// returns.
66 struct GetStartLoc : TypeLocVisitor<GetStartLoc, SourceLocation> {
67   SourceLocation VisitParenTypeLoc(ParenTypeLoc T) {
68     auto L = Visit(T.getInnerLoc());
69     if (L.isValid())
70       return L;
71     return T.getLParenLoc();
72   }
73 
74   // Types spelled in the prefix part of the declarator.
75   SourceLocation VisitPointerTypeLoc(PointerTypeLoc T) {
76     return HandlePointer(T);
77   }
78 
79   SourceLocation VisitMemberPointerTypeLoc(MemberPointerTypeLoc T) {
80     return HandlePointer(T);
81   }
82 
83   SourceLocation VisitBlockPointerTypeLoc(BlockPointerTypeLoc T) {
84     return HandlePointer(T);
85   }
86 
87   SourceLocation VisitReferenceTypeLoc(ReferenceTypeLoc T) {
88     return HandlePointer(T);
89   }
90 
91   SourceLocation VisitObjCObjectPointerTypeLoc(ObjCObjectPointerTypeLoc T) {
92     return HandlePointer(T);
93   }
94 
95   // All other cases are not important, as they are either part of declaration
96   // specifiers (e.g. inheritors of TypeSpecTypeLoc) or introduce modifiers on
97   // existing declarators (e.g. QualifiedTypeLoc). They cannot start the
98   // declarator themselves, but their underlying type can.
99   SourceLocation VisitTypeLoc(TypeLoc T) {
100     auto N = T.getNextTypeLoc();
101     if (!N)
102       return SourceLocation();
103     return Visit(N);
104   }
105 
106   SourceLocation VisitFunctionProtoTypeLoc(FunctionProtoTypeLoc T) {
107     if (T.getTypePtr()->hasTrailingReturn())
108       return SourceLocation(); // avoid recursing into the suffix of declarator.
109     return VisitTypeLoc(T);
110   }
111 
112 private:
113   template <class PtrLoc> SourceLocation HandlePointer(PtrLoc T) {
114     auto L = Visit(T.getPointeeLoc());
115     if (L.isValid())
116       return L;
117     return T.getLocalSourceRange().getBegin();
118   }
119 };
120 } // namespace
121 
122 /// Gets the range of declarator as defined by the C++ grammar. E.g.
123 ///     `int a;` -> range of `a`,
124 ///     `int *a;` -> range of `*a`,
125 ///     `int a[10];` -> range of `a[10]`,
126 ///     `int a[1][2][3];` -> range of `a[1][2][3]`,
127 ///     `int *a = nullptr` -> range of `*a = nullptr`.
128 /// FIMXE: \p Name must be a source range, e.g. for `operator+`.
129 static SourceRange getDeclaratorRange(const SourceManager &SM, TypeLoc T,
130                                       SourceLocation Name,
131                                       SourceRange Initializer) {
132   SourceLocation Start = GetStartLoc().Visit(T);
133   SourceLocation End = T.getSourceRange().getEnd();
134   assert(End.isValid());
135   if (Name.isValid()) {
136     if (Start.isInvalid())
137       Start = Name;
138     if (SM.isBeforeInTranslationUnit(End, Name))
139       End = Name;
140   }
141   if (Initializer.isValid()) {
142     assert(SM.isBeforeInTranslationUnit(End, Initializer.getEnd()));
143     End = Initializer.getEnd();
144   }
145   return SourceRange(Start, End);
146 }
147 
148 /// A helper class for constructing the syntax tree while traversing a clang
149 /// AST.
150 ///
151 /// At each point of the traversal we maintain a list of pending nodes.
152 /// Initially all tokens are added as pending nodes. When processing a clang AST
153 /// node, the clients need to:
154 ///   - create a corresponding syntax node,
155 ///   - assign roles to all pending child nodes with 'markChild' and
156 ///     'markChildToken',
157 ///   - replace the child nodes with the new syntax node in the pending list
158 ///     with 'foldNode'.
159 ///
160 /// Note that all children are expected to be processed when building a node.
161 ///
162 /// Call finalize() to finish building the tree and consume the root node.
163 class syntax::TreeBuilder {
164 public:
165   TreeBuilder(syntax::Arena &Arena) : Arena(Arena), Pending(Arena) {
166     for (const auto &T : Arena.tokenBuffer().expandedTokens())
167       LocationToToken.insert({T.location().getRawEncoding(), &T});
168   }
169 
170   llvm::BumpPtrAllocator &allocator() { return Arena.allocator(); }
171   const SourceManager &sourceManager() const { return Arena.sourceManager(); }
172 
173   /// Populate children for \p New node, assuming it covers tokens from \p
174   /// Range.
175   void foldNode(llvm::ArrayRef<syntax::Token> Range, syntax::Tree *New);
176 
177   /// Must be called with the range of each `DeclaratorDecl`. Ensures the
178   /// corresponding declarator nodes are covered by `SimpleDeclaration`.
179   void noticeDeclRange(llvm::ArrayRef<syntax::Token> Range);
180 
181   /// Notifies that we should not consume trailing semicolon when computing
182   /// token range of \p D.
183   void noticeDeclWithoutSemicolon(Decl *D);
184 
185   /// Mark the \p Child node with a corresponding \p Role. All marked children
186   /// should be consumed by foldNode.
187   /// When called on expressions (clang::Expr is derived from clang::Stmt),
188   /// wraps expressions into expression statement.
189   void markStmtChild(Stmt *Child, NodeRole Role);
190   /// Should be called for expressions in non-statement position to avoid
191   /// wrapping into expression statement.
192   void markExprChild(Expr *Child, NodeRole Role);
193   /// Set role for a token starting at \p Loc.
194   void markChildToken(SourceLocation Loc, NodeRole R);
195   /// Set role for \p T.
196   void markChildToken(const syntax::Token *T, NodeRole R);
197 
198   /// Set role for the node that spans exactly \p Range.
199   void markChild(llvm::ArrayRef<syntax::Token> Range, NodeRole R);
200   /// Set role for the delayed node that spans exactly \p Range.
201   void markDelayedChild(llvm::ArrayRef<syntax::Token> Range, NodeRole R);
202   /// Set role for the node that may or may not be delayed. Node must span
203   /// exactly \p Range.
204   void markMaybeDelayedChild(llvm::ArrayRef<syntax::Token> Range, NodeRole R);
205 
206   /// Finish building the tree and consume the root node.
207   syntax::TranslationUnit *finalize() && {
208     auto Tokens = Arena.tokenBuffer().expandedTokens();
209     assert(!Tokens.empty());
210     assert(Tokens.back().kind() == tok::eof);
211 
212     // Build the root of the tree, consuming all the children.
213     Pending.foldChildren(Arena, Tokens.drop_back(),
214                          new (Arena.allocator()) syntax::TranslationUnit);
215 
216     auto *TU = cast<syntax::TranslationUnit>(std::move(Pending).finalize());
217     TU->assertInvariantsRecursive();
218     return TU;
219   }
220 
221   /// Finds a token starting at \p L. The token must exist if \p L is valid.
222   const syntax::Token *findToken(SourceLocation L) const;
223 
224   /// getRange() finds the syntax tokens corresponding to the passed source
225   /// locations.
226   /// \p First is the start position of the first token and \p Last is the start
227   /// position of the last token.
228   llvm::ArrayRef<syntax::Token> getRange(SourceLocation First,
229                                          SourceLocation Last) const {
230     assert(First.isValid());
231     assert(Last.isValid());
232     assert(First == Last ||
233            Arena.sourceManager().isBeforeInTranslationUnit(First, Last));
234     return llvm::makeArrayRef(findToken(First), std::next(findToken(Last)));
235   }
236 
237   llvm::ArrayRef<syntax::Token>
238   getTemplateRange(const ClassTemplateSpecializationDecl *D) const {
239     auto R = D->getSourceRange();
240     auto Tokens = getRange(R.getBegin(), R.getEnd());
241     return maybeAppendSemicolon(Tokens, D);
242   }
243 
244   llvm::ArrayRef<syntax::Token> getDeclRange(const Decl *D) const {
245     llvm::ArrayRef<clang::syntax::Token> Tokens;
246     // We want to drop the template parameters for specializations.
247     if (const auto *S = llvm::dyn_cast<TagDecl>(D))
248       Tokens = getRange(S->TypeDecl::getBeginLoc(), S->getEndLoc());
249     else
250       Tokens = getRange(D->getBeginLoc(), D->getEndLoc());
251     return maybeAppendSemicolon(Tokens, D);
252   }
253   llvm::ArrayRef<syntax::Token> getExprRange(const Expr *E) const {
254     return getRange(E->getBeginLoc(), E->getEndLoc());
255   }
256   /// Find the adjusted range for the statement, consuming the trailing
257   /// semicolon when needed.
258   llvm::ArrayRef<syntax::Token> getStmtRange(const Stmt *S) const {
259     auto Tokens = getRange(S->getBeginLoc(), S->getEndLoc());
260     if (isa<CompoundStmt>(S))
261       return Tokens;
262 
263     // Some statements miss a trailing semicolon, e.g. 'return', 'continue' and
264     // all statements that end with those. Consume this semicolon here.
265     if (Tokens.back().kind() == tok::semi)
266       return Tokens;
267     return withTrailingSemicolon(Tokens);
268   }
269 
270 private:
271   llvm::ArrayRef<syntax::Token>
272   maybeAppendSemicolon(llvm::ArrayRef<syntax::Token> Tokens,
273                        const Decl *D) const {
274     if (llvm::isa<NamespaceDecl>(D))
275       return Tokens;
276     if (DeclsWithoutSemicolons.count(D))
277       return Tokens;
278     // FIXME: do not consume trailing semicolon on function definitions.
279     // Most declarations own a semicolon in syntax trees, but not in clang AST.
280     return withTrailingSemicolon(Tokens);
281   }
282 
283   llvm::ArrayRef<syntax::Token>
284   withTrailingSemicolon(llvm::ArrayRef<syntax::Token> Tokens) const {
285     assert(!Tokens.empty());
286     assert(Tokens.back().kind() != tok::eof);
287     // We never consume 'eof', so looking at the next token is ok.
288     if (Tokens.back().kind() != tok::semi && Tokens.end()->kind() == tok::semi)
289       return llvm::makeArrayRef(Tokens.begin(), Tokens.end() + 1);
290     return Tokens;
291   }
292 
293   /// A collection of trees covering the input tokens.
294   /// When created, each tree corresponds to a single token in the file.
295   /// Clients call 'foldChildren' to attach one or more subtrees to a parent
296   /// node and update the list of trees accordingly.
297   ///
298   /// Ensures that added nodes properly nest and cover the whole token stream.
299   struct Forest {
300     Forest(syntax::Arena &A) {
301       assert(!A.tokenBuffer().expandedTokens().empty());
302       assert(A.tokenBuffer().expandedTokens().back().kind() == tok::eof);
303       // Create all leaf nodes.
304       // Note that we do not have 'eof' in the tree.
305       for (auto &T : A.tokenBuffer().expandedTokens().drop_back()) {
306         auto *L = new (A.allocator()) syntax::Leaf(&T);
307         L->Original = true;
308         L->CanModify = A.tokenBuffer().spelledForExpanded(T).hasValue();
309         Trees.insert(Trees.end(), {&T, NodeAndRole{L}});
310       }
311     }
312 
313     ~Forest() { assert(DelayedFolds.empty()); }
314 
315     void assignRoleDelayed(llvm::ArrayRef<syntax::Token> Range,
316                            syntax::NodeRole Role) {
317       auto It = DelayedFolds.find(Range.begin());
318       assert(It != DelayedFolds.end());
319       assert(It->second.End == Range.end());
320       It->second.Role = Role;
321     }
322 
323     void assignRoleMaybeDelayed(llvm::ArrayRef<syntax::Token> Range,
324                                 syntax::NodeRole Role) {
325       auto It = DelayedFolds.find(Range.begin());
326       if (It == DelayedFolds.end())
327         return assignRole(Range, Role);
328       assert(It->second.End == Range.end());
329       It->second.Role = Role;
330     }
331 
332     void assignRole(llvm::ArrayRef<syntax::Token> Range,
333                     syntax::NodeRole Role) {
334       assert(!Range.empty());
335       auto It = Trees.lower_bound(Range.begin());
336       assert(It != Trees.end() && "no node found");
337       assert(It->first == Range.begin() && "no child with the specified range");
338       assert((std::next(It) == Trees.end() ||
339               std::next(It)->first == Range.end()) &&
340              "no child with the specified range");
341       It->second.Role = Role;
342     }
343 
344     /// Add \p Node to the forest and attach child nodes based on \p Tokens.
345     void foldChildren(const syntax::Arena &A,
346                       llvm::ArrayRef<syntax::Token> Tokens,
347                       syntax::Tree *Node) {
348       // Execute delayed folds inside `Tokens`.
349       auto BeginFolds = DelayedFolds.lower_bound(Tokens.begin());
350       auto EndFolds = BeginFolds;
351       for (; EndFolds != DelayedFolds.end() &&
352              EndFolds->second.End <= Tokens.end();
353            ++EndFolds)
354         ;
355       // We go in reverse order to ensure we fold deeper nodes first.
356       for (auto RevIt = EndFolds; RevIt != BeginFolds; --RevIt) {
357         auto It = std::prev(RevIt);
358         foldChildrenEager(A, llvm::makeArrayRef(It->first, It->second.End),
359                           It->second.Node);
360       }
361       DelayedFolds.erase(BeginFolds, EndFolds);
362 
363       // Attach children to `Node`.
364       foldChildrenEager(A, Tokens, Node);
365     }
366 
367     /// Schedule a call to `foldChildren` that will only be executed when
368     /// containing node is folded. The range of delayed nodes can be extended by
369     /// calling `extendDelayedFold`. Only one delayed node for each starting
370     /// token is allowed.
371     void foldChildrenDelayed(llvm::ArrayRef<syntax::Token> Tokens,
372                              syntax::Tree *Node) {
373       assert(!Tokens.empty());
374       bool Inserted =
375           DelayedFolds.insert({Tokens.begin(), DelayedFold{Tokens.end(), Node}})
376               .second;
377       (void)Inserted;
378       assert(Inserted && "Multiple delayed folds start at the same token");
379     }
380 
381     /// If there a delayed fold, starting at `ExtendedRange.begin()`, extends
382     /// its endpoint to `ExtendedRange.end()` and returns true.
383     /// Otherwise, returns false.
384     bool extendDelayedFold(llvm::ArrayRef<syntax::Token> ExtendedRange) {
385       assert(!ExtendedRange.empty());
386       auto It = DelayedFolds.find(ExtendedRange.data());
387       if (It == DelayedFolds.end())
388         return false;
389       assert(It->second.End <= ExtendedRange.end());
390       It->second.End = ExtendedRange.end();
391       return true;
392     }
393 
394     // EXPECTS: all tokens were consumed and are owned by a single root node.
395     syntax::Node *finalize() && {
396       assert(Trees.size() == 1);
397       auto *Root = Trees.begin()->second.Node;
398       Trees = {};
399       return Root;
400     }
401 
402     std::string str(const syntax::Arena &A) const {
403       std::string R;
404       for (auto It = Trees.begin(); It != Trees.end(); ++It) {
405         unsigned CoveredTokens =
406             It != Trees.end()
407                 ? (std::next(It)->first - It->first)
408                 : A.tokenBuffer().expandedTokens().end() - It->first;
409 
410         R += std::string(llvm::formatv(
411             "- '{0}' covers '{1}'+{2} tokens\n", It->second.Node->kind(),
412             It->first->text(A.sourceManager()), CoveredTokens));
413         R += It->second.Node->dump(A);
414       }
415       return R;
416     }
417 
418   private:
419     /// Implementation detail of `foldChildren`, does acutal folding ignoring
420     /// delayed folds.
421     void foldChildrenEager(const syntax::Arena &A,
422                            llvm::ArrayRef<syntax::Token> Tokens,
423                            syntax::Tree *Node) {
424       assert(Node->firstChild() == nullptr && "node already has children");
425 
426       auto *FirstToken = Tokens.begin();
427       auto BeginChildren = Trees.lower_bound(FirstToken);
428       assert((BeginChildren == Trees.end() ||
429               BeginChildren->first == FirstToken) &&
430              "fold crosses boundaries of existing subtrees");
431       auto EndChildren = Trees.lower_bound(Tokens.end());
432       assert(
433           (EndChildren == Trees.end() || EndChildren->first == Tokens.end()) &&
434           "fold crosses boundaries of existing subtrees");
435 
436       // We need to go in reverse order, because we can only prepend.
437       for (auto It = EndChildren; It != BeginChildren; --It)
438         Node->prependChildLowLevel(std::prev(It)->second.Node,
439                                    std::prev(It)->second.Role);
440 
441       // Mark that this node came from the AST and is backed by the source code.
442       Node->Original = true;
443       Node->CanModify = A.tokenBuffer().spelledForExpanded(Tokens).hasValue();
444 
445       Trees.erase(BeginChildren, EndChildren);
446       Trees.insert({FirstToken, NodeAndRole(Node)});
447     }
448     /// A with a role that should be assigned to it when adding to a parent.
449     struct NodeAndRole {
450       explicit NodeAndRole(syntax::Node *Node)
451           : Node(Node), Role(NodeRole::Unknown) {}
452 
453       syntax::Node *Node;
454       NodeRole Role;
455     };
456 
457     /// Maps from the start token to a subtree starting at that token.
458     /// Keys in the map are pointers into the array of expanded tokens, so
459     /// pointer order corresponds to the order of preprocessor tokens.
460     /// FIXME: storing the end tokens is redundant.
461     /// FIXME: the key of a map is redundant, it is also stored in NodeForRange.
462     std::map<const syntax::Token *, NodeAndRole> Trees;
463 
464     /// See documentation of `foldChildrenDelayed` for details.
465     struct DelayedFold {
466       const syntax::Token *End = nullptr;
467       syntax::Tree *Node = nullptr;
468       NodeRole Role = NodeRole::Unknown;
469     };
470     std::map<const syntax::Token *, DelayedFold> DelayedFolds;
471   };
472 
473   /// For debugging purposes.
474   std::string str() { return Pending.str(Arena); }
475 
476   syntax::Arena &Arena;
477   /// To quickly find tokens by their start location.
478   llvm::DenseMap</*SourceLocation*/ unsigned, const syntax::Token *>
479       LocationToToken;
480   Forest Pending;
481   llvm::DenseSet<Decl *> DeclsWithoutSemicolons;
482 };
483 
484 namespace {
485 class BuildTreeVisitor : public RecursiveASTVisitor<BuildTreeVisitor> {
486 public:
487   explicit BuildTreeVisitor(ASTContext &Ctx, syntax::TreeBuilder &Builder)
488       : Builder(Builder), LangOpts(Ctx.getLangOpts()) {}
489 
490   bool shouldTraversePostOrder() const { return true; }
491 
492   bool WalkUpFromDeclaratorDecl(DeclaratorDecl *DD) {
493     // Ensure declarators are covered by SimpleDeclaration.
494     Builder.noticeDeclRange(Builder.getDeclRange(DD));
495 
496     // Build the declarator node.
497     SourceRange Initializer;
498     if (auto *V = llvm::dyn_cast<VarDecl>(DD)) {
499       auto *I = V->getInit();
500       // Initializers in range-based-for are not part of the declarator
501       if (I && !V->isCXXForRangeDecl())
502         Initializer = I->getSourceRange();
503     }
504     auto Declarator = getDeclaratorRange(
505         Builder.sourceManager(), DD->getTypeSourceInfo()->getTypeLoc(),
506         getQualifiedNameStart(DD), Initializer);
507     if (Declarator.isValid()) {
508       auto Tokens =
509           Builder.getRange(Declarator.getBegin(), Declarator.getEnd());
510       Builder.foldNode(Tokens, new (allocator()) syntax::SimpleDeclarator);
511       Builder.markChild(Tokens, syntax::NodeRole::SimpleDeclaration_declarator);
512     }
513 
514     return true;
515   }
516 
517   bool WalkUpFromTypedefNameDecl(TypedefNameDecl *D) {
518     // Ensure declarators are covered by SimpleDeclaration.
519     Builder.noticeDeclRange(Builder.getDeclRange(D));
520 
521     auto R = getDeclaratorRange(
522         Builder.sourceManager(), D->getTypeSourceInfo()->getTypeLoc(),
523         /*Name=*/D->getLocation(), /*Initializer=*/SourceRange());
524     if (R.isValid()) {
525       auto Tokens = Builder.getRange(R.getBegin(), R.getEnd());
526       Builder.foldNode(Tokens, new (allocator()) syntax::SimpleDeclarator);
527       Builder.markChild(Tokens, syntax::NodeRole::SimpleDeclaration_declarator);
528     }
529     return true;
530   }
531 
532   bool VisitDecl(Decl *D) {
533     assert(!D->isImplicit());
534     Builder.foldNode(Builder.getDeclRange(D),
535                      new (allocator()) syntax::UnknownDeclaration());
536     return true;
537   }
538 
539   // RAV does not call WalkUpFrom* on explicit instantiations, so we have to
540   // override Traverse.
541   // FIXME: make RAV call WalkUpFrom* instead.
542   bool
543   TraverseClassTemplateSpecializationDecl(ClassTemplateSpecializationDecl *C) {
544     if (!RecursiveASTVisitor::TraverseClassTemplateSpecializationDecl(C))
545       return false;
546     if (C->isExplicitSpecialization())
547       return true; // we are only interested in explicit instantiations.
548     if (!WalkUpFromClassTemplateSpecializationDecl(C))
549       return false;
550     foldExplicitTemplateInstantiation(
551         Builder.getTemplateRange(C), Builder.findToken(C->getExternLoc()),
552         Builder.findToken(C->getTemplateKeywordLoc()), Builder.getDeclRange(C));
553     return true;
554   }
555 
556   bool WalkUpFromTemplateDecl(TemplateDecl *S) {
557     foldTemplateDeclaration(
558         Builder.getDeclRange(S),
559         Builder.findToken(S->getTemplateParameters()->getTemplateLoc()),
560         Builder.getDeclRange(S->getTemplatedDecl()));
561     return true;
562   }
563 
564   bool WalkUpFromTagDecl(TagDecl *C) {
565     // FIXME: build the ClassSpecifier node.
566     if (!C->isFreeStanding()) {
567       assert(C->getNumTemplateParameterLists() == 0);
568       return true;
569     }
570     // Class is a declaration specifier and needs a spanning declaration node.
571     auto DeclarationRange = Builder.getDeclRange(C);
572     Builder.foldNode(DeclarationRange,
573                      new (allocator()) syntax::SimpleDeclaration);
574 
575     // Build TemplateDeclaration nodes if we had template parameters.
576     auto ConsumeTemplateParameters = [&](const TemplateParameterList &L) {
577       const auto *TemplateKW = Builder.findToken(L.getTemplateLoc());
578       auto R = llvm::makeArrayRef(TemplateKW, DeclarationRange.end());
579       foldTemplateDeclaration(R, TemplateKW, DeclarationRange);
580 
581       DeclarationRange = R;
582     };
583     if (auto *S = llvm::dyn_cast<ClassTemplatePartialSpecializationDecl>(C))
584       ConsumeTemplateParameters(*S->getTemplateParameters());
585     for (unsigned I = C->getNumTemplateParameterLists(); 0 < I; --I)
586       ConsumeTemplateParameters(*C->getTemplateParameterList(I - 1));
587     return true;
588   }
589 
590   bool WalkUpFromTranslationUnitDecl(TranslationUnitDecl *TU) {
591     // We do not want to call VisitDecl(), the declaration for translation
592     // unit is built by finalize().
593     return true;
594   }
595 
596   bool WalkUpFromCompoundStmt(CompoundStmt *S) {
597     using NodeRole = syntax::NodeRole;
598 
599     Builder.markChildToken(S->getLBracLoc(), NodeRole::OpenParen);
600     for (auto *Child : S->body())
601       Builder.markStmtChild(Child, NodeRole::CompoundStatement_statement);
602     Builder.markChildToken(S->getRBracLoc(), NodeRole::CloseParen);
603 
604     Builder.foldNode(Builder.getStmtRange(S),
605                      new (allocator()) syntax::CompoundStatement);
606     return true;
607   }
608 
609   // Some statements are not yet handled by syntax trees.
610   bool WalkUpFromStmt(Stmt *S) {
611     Builder.foldNode(Builder.getStmtRange(S),
612                      new (allocator()) syntax::UnknownStatement);
613     return true;
614   }
615 
616   bool TraverseCXXForRangeStmt(CXXForRangeStmt *S) {
617     // We override to traverse range initializer as VarDecl.
618     // RAV traverses it as a statement, we produce invalid node kinds in that
619     // case.
620     // FIXME: should do this in RAV instead?
621     if (S->getInit() && !TraverseStmt(S->getInit()))
622       return false;
623     if (S->getLoopVariable() && !TraverseDecl(S->getLoopVariable()))
624       return false;
625     if (S->getRangeInit() && !TraverseStmt(S->getRangeInit()))
626       return false;
627     if (S->getBody() && !TraverseStmt(S->getBody()))
628       return false;
629     return true;
630   }
631 
632   bool TraverseStmt(Stmt *S) {
633     if (auto *DS = llvm::dyn_cast_or_null<DeclStmt>(S)) {
634       // We want to consume the semicolon, make sure SimpleDeclaration does not.
635       for (auto *D : DS->decls())
636         Builder.noticeDeclWithoutSemicolon(D);
637     } else if (auto *E = llvm::dyn_cast_or_null<Expr>(S)) {
638       // Do not recurse into subexpressions.
639       // We do not have syntax trees for expressions yet, so we only want to see
640       // the first top-level expression.
641       return WalkUpFromExpr(E->IgnoreImplicit());
642     }
643     return RecursiveASTVisitor::TraverseStmt(S);
644   }
645 
646   // Some expressions are not yet handled by syntax trees.
647   bool WalkUpFromExpr(Expr *E) {
648     assert(!isImplicitExpr(E) && "should be handled by TraverseStmt");
649     Builder.foldNode(Builder.getExprRange(E),
650                      new (allocator()) syntax::UnknownExpression);
651     return true;
652   }
653 
654   bool WalkUpFromNamespaceDecl(NamespaceDecl *S) {
655     auto Tokens = Builder.getDeclRange(S);
656     if (Tokens.front().kind() == tok::coloncolon) {
657       // Handle nested namespace definitions. Those start at '::' token, e.g.
658       // namespace a^::b {}
659       // FIXME: build corresponding nodes for the name of this namespace.
660       return true;
661     }
662     Builder.foldNode(Tokens, new (allocator()) syntax::NamespaceDefinition);
663     return true;
664   }
665 
666   bool TraverseParenTypeLoc(ParenTypeLoc L) {
667     // We reverse order of traversal to get the proper syntax structure.
668     if (!WalkUpFromParenTypeLoc(L))
669       return false;
670     return TraverseTypeLoc(L.getInnerLoc());
671   }
672 
673   bool WalkUpFromParenTypeLoc(ParenTypeLoc L) {
674     Builder.markChildToken(L.getLParenLoc(), syntax::NodeRole::OpenParen);
675     Builder.markChildToken(L.getRParenLoc(), syntax::NodeRole::CloseParen);
676     Builder.foldNode(Builder.getRange(L.getLParenLoc(), L.getRParenLoc()),
677                      new (allocator()) syntax::ParenDeclarator);
678     return true;
679   }
680 
681   // Declarator chunks, they are produced by type locs and some clang::Decls.
682   bool WalkUpFromArrayTypeLoc(ArrayTypeLoc L) {
683     Builder.markChildToken(L.getLBracketLoc(), syntax::NodeRole::OpenParen);
684     Builder.markExprChild(L.getSizeExpr(),
685                           syntax::NodeRole::ArraySubscript_sizeExpression);
686     Builder.markChildToken(L.getRBracketLoc(), syntax::NodeRole::CloseParen);
687     Builder.foldNode(Builder.getRange(L.getLBracketLoc(), L.getRBracketLoc()),
688                      new (allocator()) syntax::ArraySubscript);
689     return true;
690   }
691 
692   bool WalkUpFromFunctionTypeLoc(FunctionTypeLoc L) {
693     Builder.markChildToken(L.getLParenLoc(), syntax::NodeRole::OpenParen);
694     for (auto *P : L.getParams())
695       Builder.markDelayedChild(
696           Builder.getDeclRange(P),
697           syntax::NodeRole::ParametersAndQualifiers_parameter);
698     Builder.markChildToken(L.getRParenLoc(), syntax::NodeRole::CloseParen);
699     Builder.foldNode(Builder.getRange(L.getLParenLoc(), L.getEndLoc()),
700                      new (allocator()) syntax::ParametersAndQualifiers);
701     return true;
702   }
703 
704   bool WalkUpFromFunctionProtoTypeLoc(FunctionProtoTypeLoc L) {
705     if (!L.getTypePtr()->hasTrailingReturn())
706       return WalkUpFromFunctionTypeLoc(L);
707 
708     auto TrailingReturnTokens = BuildTrailingReturn(L);
709     // Finish building the node for parameters.
710     Builder.markChild(TrailingReturnTokens,
711                       syntax::NodeRole::ParametersAndQualifiers_trailingReturn);
712     return WalkUpFromFunctionTypeLoc(L);
713   }
714 
715   bool WalkUpFromMemberPointerTypeLoc(MemberPointerTypeLoc L) {
716     auto SR = L.getLocalSourceRange();
717     Builder.foldNode(Builder.getRange(SR.getBegin(), SR.getEnd()),
718                      new (allocator()) syntax::MemberPointer);
719     return true;
720   }
721 
722   // The code below is very regular, it could even be generated with some
723   // preprocessor magic. We merely assign roles to the corresponding children
724   // and fold resulting nodes.
725   bool WalkUpFromDeclStmt(DeclStmt *S) {
726     Builder.foldNode(Builder.getStmtRange(S),
727                      new (allocator()) syntax::DeclarationStatement);
728     return true;
729   }
730 
731   bool WalkUpFromNullStmt(NullStmt *S) {
732     Builder.foldNode(Builder.getStmtRange(S),
733                      new (allocator()) syntax::EmptyStatement);
734     return true;
735   }
736 
737   bool WalkUpFromSwitchStmt(SwitchStmt *S) {
738     Builder.markChildToken(S->getSwitchLoc(),
739                            syntax::NodeRole::IntroducerKeyword);
740     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
741     Builder.foldNode(Builder.getStmtRange(S),
742                      new (allocator()) syntax::SwitchStatement);
743     return true;
744   }
745 
746   bool WalkUpFromCaseStmt(CaseStmt *S) {
747     Builder.markChildToken(S->getKeywordLoc(),
748                            syntax::NodeRole::IntroducerKeyword);
749     Builder.markExprChild(S->getLHS(), syntax::NodeRole::CaseStatement_value);
750     Builder.markStmtChild(S->getSubStmt(), syntax::NodeRole::BodyStatement);
751     Builder.foldNode(Builder.getStmtRange(S),
752                      new (allocator()) syntax::CaseStatement);
753     return true;
754   }
755 
756   bool WalkUpFromDefaultStmt(DefaultStmt *S) {
757     Builder.markChildToken(S->getKeywordLoc(),
758                            syntax::NodeRole::IntroducerKeyword);
759     Builder.markStmtChild(S->getSubStmt(), syntax::NodeRole::BodyStatement);
760     Builder.foldNode(Builder.getStmtRange(S),
761                      new (allocator()) syntax::DefaultStatement);
762     return true;
763   }
764 
765   bool WalkUpFromIfStmt(IfStmt *S) {
766     Builder.markChildToken(S->getIfLoc(), syntax::NodeRole::IntroducerKeyword);
767     Builder.markStmtChild(S->getThen(),
768                           syntax::NodeRole::IfStatement_thenStatement);
769     Builder.markChildToken(S->getElseLoc(),
770                            syntax::NodeRole::IfStatement_elseKeyword);
771     Builder.markStmtChild(S->getElse(),
772                           syntax::NodeRole::IfStatement_elseStatement);
773     Builder.foldNode(Builder.getStmtRange(S),
774                      new (allocator()) syntax::IfStatement);
775     return true;
776   }
777 
778   bool WalkUpFromForStmt(ForStmt *S) {
779     Builder.markChildToken(S->getForLoc(), syntax::NodeRole::IntroducerKeyword);
780     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
781     Builder.foldNode(Builder.getStmtRange(S),
782                      new (allocator()) syntax::ForStatement);
783     return true;
784   }
785 
786   bool WalkUpFromWhileStmt(WhileStmt *S) {
787     Builder.markChildToken(S->getWhileLoc(),
788                            syntax::NodeRole::IntroducerKeyword);
789     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
790     Builder.foldNode(Builder.getStmtRange(S),
791                      new (allocator()) syntax::WhileStatement);
792     return true;
793   }
794 
795   bool WalkUpFromContinueStmt(ContinueStmt *S) {
796     Builder.markChildToken(S->getContinueLoc(),
797                            syntax::NodeRole::IntroducerKeyword);
798     Builder.foldNode(Builder.getStmtRange(S),
799                      new (allocator()) syntax::ContinueStatement);
800     return true;
801   }
802 
803   bool WalkUpFromBreakStmt(BreakStmt *S) {
804     Builder.markChildToken(S->getBreakLoc(),
805                            syntax::NodeRole::IntroducerKeyword);
806     Builder.foldNode(Builder.getStmtRange(S),
807                      new (allocator()) syntax::BreakStatement);
808     return true;
809   }
810 
811   bool WalkUpFromReturnStmt(ReturnStmt *S) {
812     Builder.markChildToken(S->getReturnLoc(),
813                            syntax::NodeRole::IntroducerKeyword);
814     Builder.markExprChild(S->getRetValue(),
815                           syntax::NodeRole::ReturnStatement_value);
816     Builder.foldNode(Builder.getStmtRange(S),
817                      new (allocator()) syntax::ReturnStatement);
818     return true;
819   }
820 
821   bool WalkUpFromCXXForRangeStmt(CXXForRangeStmt *S) {
822     Builder.markChildToken(S->getForLoc(), syntax::NodeRole::IntroducerKeyword);
823     Builder.markStmtChild(S->getBody(), syntax::NodeRole::BodyStatement);
824     Builder.foldNode(Builder.getStmtRange(S),
825                      new (allocator()) syntax::RangeBasedForStatement);
826     return true;
827   }
828 
829   bool WalkUpFromEmptyDecl(EmptyDecl *S) {
830     Builder.foldNode(Builder.getDeclRange(S),
831                      new (allocator()) syntax::EmptyDeclaration);
832     return true;
833   }
834 
835   bool WalkUpFromStaticAssertDecl(StaticAssertDecl *S) {
836     Builder.markExprChild(S->getAssertExpr(),
837                           syntax::NodeRole::StaticAssertDeclaration_condition);
838     Builder.markExprChild(S->getMessage(),
839                           syntax::NodeRole::StaticAssertDeclaration_message);
840     Builder.foldNode(Builder.getDeclRange(S),
841                      new (allocator()) syntax::StaticAssertDeclaration);
842     return true;
843   }
844 
845   bool WalkUpFromLinkageSpecDecl(LinkageSpecDecl *S) {
846     Builder.foldNode(Builder.getDeclRange(S),
847                      new (allocator()) syntax::LinkageSpecificationDeclaration);
848     return true;
849   }
850 
851   bool WalkUpFromNamespaceAliasDecl(NamespaceAliasDecl *S) {
852     Builder.foldNode(Builder.getDeclRange(S),
853                      new (allocator()) syntax::NamespaceAliasDefinition);
854     return true;
855   }
856 
857   bool WalkUpFromUsingDirectiveDecl(UsingDirectiveDecl *S) {
858     Builder.foldNode(Builder.getDeclRange(S),
859                      new (allocator()) syntax::UsingNamespaceDirective);
860     return true;
861   }
862 
863   bool WalkUpFromUsingDecl(UsingDecl *S) {
864     Builder.foldNode(Builder.getDeclRange(S),
865                      new (allocator()) syntax::UsingDeclaration);
866     return true;
867   }
868 
869   bool WalkUpFromUnresolvedUsingValueDecl(UnresolvedUsingValueDecl *S) {
870     Builder.foldNode(Builder.getDeclRange(S),
871                      new (allocator()) syntax::UsingDeclaration);
872     return true;
873   }
874 
875   bool WalkUpFromUnresolvedUsingTypenameDecl(UnresolvedUsingTypenameDecl *S) {
876     Builder.foldNode(Builder.getDeclRange(S),
877                      new (allocator()) syntax::UsingDeclaration);
878     return true;
879   }
880 
881   bool WalkUpFromTypeAliasDecl(TypeAliasDecl *S) {
882     Builder.foldNode(Builder.getDeclRange(S),
883                      new (allocator()) syntax::TypeAliasDeclaration);
884     return true;
885   }
886 
887 private:
888   /// Returns the range of the built node.
889   llvm::ArrayRef<syntax::Token> BuildTrailingReturn(FunctionProtoTypeLoc L) {
890     assert(L.getTypePtr()->hasTrailingReturn());
891 
892     auto ReturnedType = L.getReturnLoc();
893     // Build node for the declarator, if any.
894     auto ReturnDeclaratorRange =
895         getDeclaratorRange(this->Builder.sourceManager(), ReturnedType,
896                            /*Name=*/SourceLocation(),
897                            /*Initializer=*/SourceLocation());
898     llvm::ArrayRef<syntax::Token> ReturnDeclaratorTokens;
899     if (ReturnDeclaratorRange.isValid()) {
900       ReturnDeclaratorTokens = Builder.getRange(
901           ReturnDeclaratorRange.getBegin(), ReturnDeclaratorRange.getEnd());
902       Builder.foldNode(ReturnDeclaratorTokens,
903                        new (allocator()) syntax::SimpleDeclarator);
904     }
905 
906     // Build node for trailing return type.
907     auto Return =
908         Builder.getRange(ReturnedType.getBeginLoc(), ReturnedType.getEndLoc());
909     const auto *Arrow = Return.begin() - 1;
910     assert(Arrow->kind() == tok::arrow);
911     auto Tokens = llvm::makeArrayRef(Arrow, Return.end());
912     Builder.markChildToken(Arrow, syntax::NodeRole::TrailingReturnType_arrow);
913     if (!ReturnDeclaratorTokens.empty())
914       Builder.markChild(ReturnDeclaratorTokens,
915                         syntax::NodeRole::TrailingReturnType_declarator);
916     Builder.foldNode(Tokens, new (allocator()) syntax::TrailingReturnType);
917     return Tokens;
918   }
919 
920   void
921   foldExplicitTemplateInstantiation(ArrayRef<syntax::Token> Range,
922                                     const syntax::Token *ExternKW,
923                                     const syntax::Token *TemplateKW,
924                                     ArrayRef<syntax::Token> InnerDeclaration) {
925     assert(!ExternKW || ExternKW->kind() == tok::kw_extern);
926     assert(TemplateKW && TemplateKW->kind() == tok::kw_template);
927     Builder.markChildToken(
928         ExternKW,
929         syntax::NodeRole::ExplicitTemplateInstantiation_externKeyword);
930     Builder.markChildToken(TemplateKW, syntax::NodeRole::IntroducerKeyword);
931     Builder.markChild(
932         InnerDeclaration,
933         syntax::NodeRole::ExplicitTemplateInstantiation_declaration);
934     Builder.foldNode(Range,
935                      new (allocator()) syntax::ExplicitTemplateInstantiation);
936   }
937 
938   void foldTemplateDeclaration(ArrayRef<syntax::Token> Range,
939                                const syntax::Token *TemplateKW,
940                                ArrayRef<syntax::Token> TemplatedDeclaration) {
941     assert(TemplateKW && TemplateKW->kind() == tok::kw_template);
942     Builder.markChildToken(TemplateKW, syntax::NodeRole::IntroducerKeyword);
943     Builder.markMaybeDelayedChild(
944         TemplatedDeclaration,
945         syntax::NodeRole::TemplateDeclaration_declaration);
946     Builder.foldNode(Range, new (allocator()) syntax::TemplateDeclaration);
947   }
948 
  /// A small helper to save some typing: returns the allocator of the
  /// underlying tree builder, which the methods of this class pass to
  /// placement new when creating syntax nodes.
  llvm::BumpPtrAllocator &allocator() { return Builder.allocator(); }
951 
952   syntax::TreeBuilder &Builder;
953   const LangOptions &LangOpts;
954 };
955 } // namespace
956 
/// Folds the pending children covered by \p Range under the node \p New by
/// forwarding to Pending.foldChildren with this builder's arena.
void syntax::TreeBuilder::foldNode(llvm::ArrayRef<syntax::Token> Range,
                                   syntax::Tree *New) {
  Pending.foldChildren(Arena, Range, New);
}
961 
962 void syntax::TreeBuilder::noticeDeclRange(llvm::ArrayRef<syntax::Token> Range) {
963   if (Pending.extendDelayedFold(Range))
964     return;
965   Pending.foldChildrenDelayed(Range,
966                               new (allocator()) syntax::SimpleDeclaration);
967 }
968 
/// Records \p D in the DeclsWithoutSemicolons set.
void syntax::TreeBuilder::noticeDeclWithoutSemicolon(Decl *D) {
  DeclsWithoutSemicolons.insert(D);
}
972 
973 void syntax::TreeBuilder::markChildToken(SourceLocation Loc, NodeRole Role) {
974   if (Loc.isInvalid())
975     return;
976   Pending.assignRole(*findToken(Loc), Role);
977 }
978 
979 void syntax::TreeBuilder::markChildToken(const syntax::Token *T, NodeRole R) {
980   if (!T)
981     return;
982   Pending.assignRole(*T, R);
983 }
984 
/// Assigns the role \p R to \p Range by forwarding to Pending.assignRole.
void syntax::TreeBuilder::markChild(llvm::ArrayRef<syntax::Token> Range,
                                    NodeRole R) {
  Pending.assignRole(Range, R);
}
989 
/// Assigns the role \p R to \p Range by forwarding to
/// Pending.assignRoleDelayed.
void syntax::TreeBuilder::markDelayedChild(llvm::ArrayRef<syntax::Token> Range,
                                           NodeRole R) {
  Pending.assignRoleDelayed(Range, R);
}
994 
/// Assigns the role \p R to \p Range by forwarding to
/// Pending.assignRoleMaybeDelayed.
void syntax::TreeBuilder::markMaybeDelayedChild(
    llvm::ArrayRef<syntax::Token> Range, NodeRole R) {
  Pending.assignRoleMaybeDelayed(Range, R);
}
999 
1000 void syntax::TreeBuilder::markStmtChild(Stmt *Child, NodeRole Role) {
1001   if (!Child)
1002     return;
1003 
1004   auto Range = getStmtRange(Child);
1005   // This is an expression in a statement position, consume the trailing
1006   // semicolon and form an 'ExpressionStatement' node.
1007   if (auto *E = dyn_cast<Expr>(Child)) {
1008     Pending.assignRole(getExprRange(E),
1009                        NodeRole::ExpressionStatement_expression);
1010     // 'getRange(Stmt)' ensures this already covers a trailing semicolon.
1011     Pending.foldChildren(Arena, Range,
1012                          new (allocator()) syntax::ExpressionStatement);
1013   }
1014   Pending.assignRole(Range, Role);
1015 }
1016 
1017 void syntax::TreeBuilder::markExprChild(Expr *Child, NodeRole Role) {
1018   if (!Child)
1019     return;
1020 
1021   Pending.assignRole(getExprRange(Child), Role);
1022 }
1023 
1024 const syntax::Token *syntax::TreeBuilder::findToken(SourceLocation L) const {
1025   if (L.isInvalid())
1026     return nullptr;
1027   auto It = LocationToToken.find(L.getRawEncoding());
1028   assert(It != LocationToToken.end());
1029   return It->second;
1030 }
1031 
1032 syntax::TranslationUnit *
1033 syntax::buildSyntaxTree(Arena &A, const TranslationUnitDecl &TU) {
1034   TreeBuilder Builder(A);
1035   BuildTreeVisitor(TU.getASTContext(), Builder).TraverseAST(TU.getASTContext());
1036   return std::move(Builder).finalize();
1037 }
1038