1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "clang/AST/RecursiveASTVisitor.h" 10 #include "clang/AST/ASTConsumer.h" 11 #include "clang/AST/ASTContext.h" 12 #include "clang/AST/Attr.h" 13 #include "clang/AST/Decl.h" 14 #include "clang/AST/TypeLoc.h" 15 #include "clang/Frontend/FrontendAction.h" 16 #include "clang/Tooling/Tooling.h" 17 #include "llvm/ADT/FunctionExtras.h" 18 #include "llvm/ADT/STLExtras.h" 19 #include "gmock/gmock.h" 20 #include "gtest/gtest.h" 21 #include <cassert> 22 23 using namespace clang; 24 using ::testing::ElementsAre; 25 26 namespace { 27 class ProcessASTAction : public clang::ASTFrontendAction { 28 public: 29 ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process) 30 : Process(std::move(Process)) { 31 assert(this->Process); 32 } 33 34 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, 35 StringRef InFile) { 36 class Consumer : public ASTConsumer { 37 public: 38 Consumer(llvm::function_ref<void(ASTContext &CTx)> Process) 39 : Process(Process) {} 40 41 void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); } 42 43 private: 44 llvm::function_ref<void(ASTContext &CTx)> Process; 45 }; 46 47 return std::make_unique<Consumer>(Process); 48 } 49 50 private: 51 llvm::unique_function<void(clang::ASTContext &)> Process; 52 }; 53 54 enum class VisitEvent { 55 StartTraverseFunction, 56 EndTraverseFunction, 57 StartTraverseAttr, 58 EndTraverseAttr, 59 StartTraverseEnum, 60 EndTraverseEnum, 61 StartTraverseTypedefType, 62 EndTraverseTypedefType, 63 }; 64 65 class CollectInterestingEvents 66 : public RecursiveASTVisitor<CollectInterestingEvents> { 67 public: 68 bool TraverseFunctionDecl(FunctionDecl *D) { 69 Events.push_back(VisitEvent::StartTraverseFunction); 70 bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D); 71 Events.push_back(VisitEvent::EndTraverseFunction); 72 73 return Ret; 74 } 75 76 bool TraverseAttr(Attr *A) { 77 Events.push_back(VisitEvent::StartTraverseAttr); 78 bool Ret = RecursiveASTVisitor::TraverseAttr(A); 79 Events.push_back(VisitEvent::EndTraverseAttr); 80 81 return Ret; 82 } 83 84 bool TraverseEnumDecl(EnumDecl *D) { 85 Events.push_back(VisitEvent::StartTraverseEnum); 86 bool Ret = RecursiveASTVisitor::TraverseEnumDecl(D); 87 Events.push_back(VisitEvent::EndTraverseEnum); 88 89 return Ret; 90 } 91 92 bool TraverseTypedefTypeLoc(TypedefTypeLoc TL) { 93 Events.push_back(VisitEvent::StartTraverseTypedefType); 94 bool Ret = RecursiveASTVisitor::TraverseTypedefTypeLoc(TL); 95 Events.push_back(VisitEvent::EndTraverseTypedefType); 96 97 return Ret; 98 } 99 100 std::vector<VisitEvent> takeEvents() && { return std::move(Events); } 101 102 private: 103 std::vector<VisitEvent> Events; 104 }; 105 106 std::vector<VisitEvent> collectEvents(llvm::StringRef Code) { 107 CollectInterestingEvents Visitor; 108 clang::tooling::runToolOnCode( 109 std::make_unique<ProcessASTAction>( 110 [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }), 111 Code); 112 return std::move(Visitor).takeEvents(); 113 } 114 } // namespace 115 116 TEST(RecursiveASTVisitorTest, AttributesInsideDecls) { 117 /// Check attributes are traversed inside TraverseFunctionDecl. 118 llvm::StringRef Code = R"cpp( 119 __attribute__((annotate("something"))) int foo() { return 10; } 120 )cpp"; 121 122 EXPECT_THAT(collectEvents(Code), 123 ElementsAre(VisitEvent::StartTraverseFunction, 124 VisitEvent::StartTraverseAttr, 125 VisitEvent::EndTraverseAttr, 126 VisitEvent::EndTraverseFunction)); 127 } 128 129 TEST(RecursiveASTVisitorTest, EnumDeclWithBase) { 130 // Check enum and its integer base is visited. 131 llvm::StringRef Code = R"cpp( 132 typedef int Foo; 133 enum Bar : Foo; 134 )cpp"; 135 136 EXPECT_THAT(collectEvents(Code), 137 ElementsAre(VisitEvent::StartTraverseEnum, 138 VisitEvent::StartTraverseTypedefType, 139 VisitEvent::EndTraverseTypedefType, 140 VisitEvent::EndTraverseEnum)); 141 } 142