1 //===--- SmartPtrArrayMismatchCheck.cpp - clang-tidy ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SmartPtrArrayMismatchCheck.h"
10 #include "../utils/ASTUtils.h"
11 #include "clang/ASTMatchers/ASTMatchFinder.h"
12 #include "clang/Lex/Lexer.h"
13 
14 using namespace clang::ast_matchers;
15 
16 namespace clang {
17 namespace tidy {
18 namespace bugprone {
19 
20 namespace {
21 
// Bind names shared by registerMatchers() (which binds the matched nodes)
// and check() (which retrieves them from the match result).
constexpr char ConstructExprN[]{"found_construct_expr"};
constexpr char NewExprN[]{"found_new_expr"};
constexpr char ConstructorN[]{"found_constructor"};
25 
isInSingleDeclStmt(const DeclaratorDecl * D)26 bool isInSingleDeclStmt(const DeclaratorDecl *D) {
27   const DynTypedNodeList Parents =
28       D->getASTContext().getParentMapContext().getParents(*D);
29   for (const DynTypedNode &PNode : Parents)
30     if (const auto *PDecl = PNode.get<DeclStmt>())
31       return PDecl->isSingleDecl();
32   return false;
33 }
34 
getConstructedVarOrField(const Expr * FoundConstructExpr,ASTContext & Ctx)35 const DeclaratorDecl *getConstructedVarOrField(const Expr *FoundConstructExpr,
36                                                ASTContext &Ctx) {
37   const DynTypedNodeList ConstructParents =
38       Ctx.getParentMapContext().getParents(*FoundConstructExpr);
39   if (ConstructParents.size() != 1)
40     return nullptr;
41   const auto *ParentDecl = ConstructParents.begin()->get<DeclaratorDecl>();
42   if (isa_and_nonnull<VarDecl, FieldDecl>(ParentDecl))
43     return ParentDecl;
44 
45   return nullptr;
46 }
47 
48 } // namespace
49 
50 const char SmartPtrArrayMismatchCheck::PointerTypeN[] = "pointer_type";
51 
SmartPtrArrayMismatchCheck(StringRef Name,ClangTidyContext * Context,StringRef SmartPointerName)52 SmartPtrArrayMismatchCheck::SmartPtrArrayMismatchCheck(
53     StringRef Name, ClangTidyContext *Context, StringRef SmartPointerName)
54     : ClangTidyCheck(Name, Context), SmartPointerName(SmartPointerName) {}
55 
storeOptions(ClangTidyOptions::OptionMap & Opts)56 void SmartPtrArrayMismatchCheck::storeOptions(
57     ClangTidyOptions::OptionMap &Opts) {}
58 
registerMatchers(MatchFinder * Finder)59 void SmartPtrArrayMismatchCheck::registerMatchers(MatchFinder *Finder) {
60   // For both shared and unique pointers, we need to find constructor with
61   // exactly one parameter that has the pointer type. Other constructors are
62   // not applicable for this check.
63   auto FindConstructor =
64       cxxConstructorDecl(ofClass(getSmartPointerClassMatcher()),
65                          parameterCountIs(1), isExplicit())
66           .bind(ConstructorN);
67   auto FindConstructExpr =
68       cxxConstructExpr(
69           hasDeclaration(FindConstructor), argumentCountIs(1),
70           hasArgument(
71               0, cxxNewExpr(isArray(), hasType(pointerType(pointee(
72                                            equalsBoundNode(PointerTypeN)))))
73                      .bind(NewExprN)))
74           .bind(ConstructExprN);
75   Finder->addMatcher(FindConstructExpr, this);
76 }
77 
check(const MatchFinder::MatchResult & Result)78 void SmartPtrArrayMismatchCheck::check(const MatchFinder::MatchResult &Result) {
79   const auto *FoundNewExpr = Result.Nodes.getNodeAs<CXXNewExpr>(NewExprN);
80   const auto *FoundConstructExpr =
81       Result.Nodes.getNodeAs<CXXConstructExpr>(ConstructExprN);
82   const auto *FoundConstructorDecl =
83       Result.Nodes.getNodeAs<CXXConstructorDecl>(ConstructorN);
84 
85   ASTContext &Ctx = FoundConstructorDecl->getASTContext();
86   const DeclaratorDecl *VarOrField =
87       getConstructedVarOrField(FoundConstructExpr, Ctx);
88 
89   auto D = diag(FoundNewExpr->getBeginLoc(),
90                 "%0 pointer to non-array is initialized with array")
91            << SmartPointerName;
92   D << FoundNewExpr->getSourceRange();
93 
94   if (VarOrField) {
95     auto TSTypeLoc = VarOrField->getTypeSourceInfo()
96                          ->getTypeLoc()
97                          .getAsAdjusted<clang::TemplateSpecializationTypeLoc>();
98     assert(TSTypeLoc.getNumArgs() >= 1 &&
99            "Matched type should have at least 1 template argument.");
100 
101     SourceRange TemplateArgumentRange = TSTypeLoc.getArgLoc(0)
102                                             .getTypeSourceInfo()
103                                             ->getTypeLoc()
104                                             .getLocalSourceRange();
105     D << TemplateArgumentRange;
106 
107     if (isInSingleDeclStmt(VarOrField)) {
108       const SourceManager &SM = Ctx.getSourceManager();
109       if (!utils::rangeCanBeFixed(TemplateArgumentRange, &SM))
110         return;
111 
112       SourceLocation InsertLoc = Lexer::getLocForEndOfToken(
113           TemplateArgumentRange.getEnd(), 0, SM, Ctx.getLangOpts());
114       D << FixItHint::CreateInsertion(InsertLoc, "[]");
115     }
116   }
117 }
118 
119 } // namespace bugprone
120 } // namespace tidy
121 } // namespace clang
122