1 //===- unittest/AST/RecursiveASTVisitorTest.cpp ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "clang/AST/RecursiveASTVisitor.h" 10 #include "clang/AST/ASTConsumer.h" 11 #include "clang/AST/ASTContext.h" 12 #include "clang/AST/Attr.h" 13 #include "clang/Frontend/FrontendAction.h" 14 #include "clang/Tooling/Tooling.h" 15 #include "llvm/ADT/FunctionExtras.h" 16 #include "llvm/ADT/STLExtras.h" 17 #include "gmock/gmock.h" 18 #include "gtest/gtest.h" 19 #include <cassert> 20 21 using namespace clang; 22 using ::testing::ElementsAre; 23 24 namespace { 25 class ProcessASTAction : public clang::ASTFrontendAction { 26 public: 27 ProcessASTAction(llvm::unique_function<void(clang::ASTContext &)> Process) 28 : Process(std::move(Process)) { 29 assert(this->Process); 30 } 31 32 std::unique_ptr<ASTConsumer> CreateASTConsumer(CompilerInstance &CI, 33 StringRef InFile) { 34 class Consumer : public ASTConsumer { 35 public: 36 Consumer(llvm::function_ref<void(ASTContext &CTx)> Process) 37 : Process(Process) {} 38 39 void HandleTranslationUnit(ASTContext &Ctx) override { Process(Ctx); } 40 41 private: 42 llvm::function_ref<void(ASTContext &CTx)> Process; 43 }; 44 45 return std::make_unique<Consumer>(Process); 46 } 47 48 private: 49 llvm::unique_function<void(clang::ASTContext &)> Process; 50 }; 51 52 enum class VisitEvent { 53 StartTraverseFunction, 54 EndTraverseFunction, 55 StartTraverseAttr, 56 EndTraverseAttr 57 }; 58 59 class CollectInterestingEvents 60 : public RecursiveASTVisitor<CollectInterestingEvents> { 61 public: 62 bool TraverseFunctionDecl(FunctionDecl *D) { 63 Events.push_back(VisitEvent::StartTraverseFunction); 64 bool Ret = RecursiveASTVisitor::TraverseFunctionDecl(D); 65 Events.push_back(VisitEvent::EndTraverseFunction); 66 67 return Ret; 68 } 69 70 bool TraverseAttr(Attr *A) { 71 Events.push_back(VisitEvent::StartTraverseAttr); 72 bool Ret = RecursiveASTVisitor::TraverseAttr(A); 73 Events.push_back(VisitEvent::EndTraverseAttr); 74 75 return Ret; 76 } 77 78 std::vector<VisitEvent> takeEvents() && { return std::move(Events); } 79 80 private: 81 std::vector<VisitEvent> Events; 82 }; 83 84 std::vector<VisitEvent> collectEvents(llvm::StringRef Code) { 85 CollectInterestingEvents Visitor; 86 clang::tooling::runToolOnCode( 87 new ProcessASTAction( 88 [&](clang::ASTContext &Ctx) { Visitor.TraverseAST(Ctx); }), 89 Code); 90 return std::move(Visitor).takeEvents(); 91 } 92 } // namespace 93 94 TEST(RecursiveASTVisitorTest, AttributesInsideDecls) { 95 /// Check attributes are traversed inside TraverseFunctionDecl. 96 llvm::StringRef Code = R"cpp( 97 __attribute__((annotate("something"))) int foo() { return 10; } 98 )cpp"; 99 100 EXPECT_THAT(collectEvents(Code), 101 ElementsAre(VisitEvent::StartTraverseFunction, 102 VisitEvent::StartTraverseAttr, 103 VisitEvent::EndTraverseAttr, 104 VisitEvent::EndTraverseFunction)); 105 } 106