1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/AST/RecursiveASTVisitor.h"
10 #include "clang/AST/ASTConsumer.h"
11 #include "clang/AST/ASTContext.h"
12 #include "clang/AST/Attr.h"
13 #include "clang/AST/Decl.h"
14 #include "clang/AST/TypeLoc.h"
15 #include "clang/Frontend/FrontendAction.h"
16 #include "clang/Tooling/Tooling.h"
17 #include "llvm/ADT/FunctionExtras.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "gmock/gmock.h"
20 #include "gtest/gtest.h"
21 #include <cassert>
22 
23 using namespace clang;
24 using ::testing::ElementsAre;
25 
26 namespace {
27 class ProcessASTAction : public clang::ASTFrontendAction {
28 public:
29   ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process)
30       : Process(std::move(Process)) {
31     assert(this->Process);
32   }
33 
34   std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI,
35                                                  StringRef InFile) {
36     class Consumer : public ASTConsumer {
37     public:
38       Consumer(llvm::function_ref<void(ASTContext &CTx)> Process)
39           : Process(Process) {}
40 
41       void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); }
42 
43     private:
44       llvm::function_ref<void(ASTContext &CTx)> Process;
45     };
46 
47     return std::make_unique<Consumer>(Process);
48   }
49 
50 private:
51   llvm::unique_function<void(clang::ASTContext &)> Process;
52 };
53 
/// Events recorded while traversing an AST. Every traversal hook of interest
/// has a Start* value and a matching End* value.
enum class VisitEvent {
  // Function declarations.
  StartTraverseFunction,
  EndTraverseFunction,

  // Attributes.
  StartTraverseAttr,
  EndTraverseAttr,

  // Enum declarations.
  StartTraverseEnum,
  EndTraverseEnum,

  // Typedef type locations.
  StartTraverseTypedefType,
  EndTraverseTypedefType,
};
64 
65 class CollectInterestingEvents
66     : public RecursiveASTVisitor<CollectInterestingEvents> {
67 public:
68   bool TraverseFunctionDecl(FunctionDecl *D) {
69     Events.push_back(VisitEvent::StartTraverseFunction);
70     bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D);
71     Events.push_back(VisitEvent::EndTraverseFunction);
72 
73     return Ret;
74   }
75 
76   bool TraverseAttr(Attr *A) {
77     Events.push_back(VisitEvent::StartTraverseAttr);
78     bool Ret = RecursiveASTVisitor::TraverseAttr(A);
79     Events.push_back(VisitEvent::EndTraverseAttr);
80 
81     return Ret;
82   }
83 
84   bool TraverseEnumDecl(EnumDecl *D) {
85     Events.push_back(VisitEvent::StartTraverseEnum);
86     bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D);
87     Events.push_back(VisitEvent::EndTraverseEnum);
88 
89     return Ret;
90   }
91 
92   bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) {
93     Events.push_back(VisitEvent::StartTraverseTypedefType);
94     bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL);
95     Events.push_back(VisitEvent::EndTraverseTypedefType);
96 
97     return Ret;
98   }
99 
100   std::vector<VisitEvent> takeEvents() && { return std::move(Events); }
101 
102 private:
103   std::vector<VisitEvent> Events;
104 };
105 
106 std::vector<VisitEvent> collectEvents(llvm::StringRef Code) {
107   CollectInterestingEvents Visitor;
108   clang::tooling::runToolOnCode(
109       std::make_unique<ProcessASTAction>(
110           [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }),
111       Code);
112   return std::move(Visitor).takeEvents();
113 }
114 } // namespace
115 
116 TEST(RecursiveASTVisitorTest, AttributesInsideDecls) {
117   /// Check attributes are traversed inside TraverseFunctionDecl.
118   llvm::StringRef Code = R"cpp(
119 __attribute__((annotate("something"))) int foo() { return 10; }
120   )cpp";
121 
122   EXPECT_THAT(collectEvents(Code),
123               ElementsAre(VisitEvent::StartTraverseFunction,
124                           VisitEvent::StartTraverseAttr,
125                           VisitEvent::EndTraverseAttr,
126                           VisitEvent::EndTraverseFunction));
127 }
128 
129 TEST(RecursiveASTVisitorTest, EnumDeclWithBase) {
130   // Check enum and its integer base is visited.
131   llvm::StringRef Code = R"cpp(
132   typedef int Foo;
133   enum Bar : Foo;
134   )cpp";
135 
136   EXPECT_THAT(collectEvents(Code),
137               ElementsAre(VisitEvent::StartTraverseEnum,
138                           VisitEvent::StartTraverseTypedefType,
139                           VisitEvent::EndTraverseTypedefType,
140                           VisitEvent::EndTraverseEnum));
141 }
142