1 //===-- SimplifyBooleanExprCheck.cpp - clang-tidy -------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "SimplifyBooleanExprCheck.h"
10 #include "clang/AST/RecursiveASTVisitor.h"
11 #include "clang/Lex/Lexer.h"
12 #include "llvm/Support/SaveAndRestore.h"
13 
14 #include <string>
15 #include <utility>
16 
17 using namespace clang::ast_matchers;
18 
19 namespace clang {
20 namespace tidy {
21 namespace readability {
22 
23 namespace {
24 
getText(const ASTContext & Context,SourceRange Range)25 StringRef getText(const ASTContext &Context, SourceRange Range) {
26   return Lexer::getSourceText(CharSourceRange::getTokenRange(Range),
27                               Context.getSourceManager(),
28                               Context.getLangOpts());
29 }
30 
getText(const ASTContext & Context,T & Node)31 template <typename T> StringRef getText(const ASTContext &Context, T &Node) {
32   return getText(Context, Node.getSourceRange());
33 }
34 
35 } // namespace
36 
// Diagnostic texts shared by several reporting functions below.

// Used when a `bool` literal is an operand of `&&`, `||`, `==` or `!=`.
static constexpr char SimplifyOperatorDiagnostic[] =
    "redundant boolean literal supplied to boolean operator";
// Used when an `if` condition is itself a (possibly negated) bool literal.
static constexpr char SimplifyConditionDiagnostic[] =
    "redundant boolean literal in if statement condition";
// Used when `if (C) return <bool>; ... return <!bool>;` can collapse to
// `return C;`.
static constexpr char SimplifyConditionalReturnDiagnostic[] =
    "redundant boolean literal in conditional return statement";
43 
needsParensAfterUnaryNegation(const Expr * E)44 static bool needsParensAfterUnaryNegation(const Expr *E) {
45   E = E->IgnoreImpCasts();
46   if (isa<BinaryOperator>(E) || isa<ConditionalOperator>(E))
47     return true;
48 
49   if (const auto *Op = dyn_cast<CXXOperatorCallExpr>(E))
50     return Op->getNumArgs() == 2 && Op->getOperator() != OO_Call &&
51            Op->getOperator() != OO_Subscript;
52 
53   return false;
54 }
55 
56 static std::pair<BinaryOperatorKind, BinaryOperatorKind> Opposites[] = {
57     {BO_LT, BO_GE}, {BO_GT, BO_LE}, {BO_EQ, BO_NE}};
58 
negatedOperator(const BinaryOperator * BinOp)59 static StringRef negatedOperator(const BinaryOperator *BinOp) {
60   const BinaryOperatorKind Opcode = BinOp->getOpcode();
61   for (auto NegatableOp : Opposites) {
62     if (Opcode == NegatableOp.first)
63       return BinOp->getOpcodeStr(NegatableOp.second);
64     if (Opcode == NegatableOp.second)
65       return BinOp->getOpcodeStr(NegatableOp.first);
66   }
67   return {};
68 }
69 
70 static std::pair<OverloadedOperatorKind, StringRef> OperatorNames[] = {
71     {OO_EqualEqual, "=="},   {OO_ExclaimEqual, "!="}, {OO_Less, "<"},
72     {OO_GreaterEqual, ">="}, {OO_Greater, ">"},       {OO_LessEqual, "<="}};
73 
getOperatorName(OverloadedOperatorKind OpKind)74 static StringRef getOperatorName(OverloadedOperatorKind OpKind) {
75   for (auto Name : OperatorNames) {
76     if (Name.first == OpKind)
77       return Name.second;
78   }
79 
80   return {};
81 }
82 
83 static std::pair<OverloadedOperatorKind, OverloadedOperatorKind>
84     OppositeOverloads[] = {{OO_EqualEqual, OO_ExclaimEqual},
85                            {OO_Less, OO_GreaterEqual},
86                            {OO_Greater, OO_LessEqual}};
87 
negatedOperator(const CXXOperatorCallExpr * OpCall)88 static StringRef negatedOperator(const CXXOperatorCallExpr *OpCall) {
89   const OverloadedOperatorKind Opcode = OpCall->getOperator();
90   for (auto NegatableOp : OppositeOverloads) {
91     if (Opcode == NegatableOp.first)
92       return getOperatorName(NegatableOp.second);
93     if (Opcode == NegatableOp.second)
94       return getOperatorName(NegatableOp.first);
95   }
96   return {};
97 }
98 
asBool(StringRef Text,bool NeedsStaticCast)99 static std::string asBool(StringRef Text, bool NeedsStaticCast) {
100   if (NeedsStaticCast)
101     return ("static_cast<bool>(" + Text + ")").str();
102 
103   return std::string(Text);
104 }
105 
needsNullPtrComparison(const Expr * E)106 static bool needsNullPtrComparison(const Expr *E) {
107   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E))
108     return ImpCast->getCastKind() == CK_PointerToBoolean ||
109            ImpCast->getCastKind() == CK_MemberPointerToBoolean;
110 
111   return false;
112 }
113 
needsZeroComparison(const Expr * E)114 static bool needsZeroComparison(const Expr *E) {
115   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E))
116     return ImpCast->getCastKind() == CK_IntegralToBoolean;
117 
118   return false;
119 }
120 
needsStaticCast(const Expr * E)121 static bool needsStaticCast(const Expr *E) {
122   if (const auto *ImpCast = dyn_cast<ImplicitCastExpr>(E)) {
123     if (ImpCast->getCastKind() == CK_UserDefinedConversion &&
124         ImpCast->getSubExpr()->getType()->isBooleanType()) {
125       if (const auto *MemCall =
126               dyn_cast<CXXMemberCallExpr>(ImpCast->getSubExpr())) {
127         if (const auto *MemDecl =
128                 dyn_cast<CXXConversionDecl>(MemCall->getMethodDecl())) {
129           if (MemDecl->isExplicit())
130             return true;
131         }
132       }
133     }
134   }
135 
136   E = E->IgnoreImpCasts();
137   return !E->getType()->isBooleanType();
138 }
139 
compareExpressionToConstant(const ASTContext & Context,const Expr * E,bool Negated,const char * Constant)140 static std::string compareExpressionToConstant(const ASTContext &Context,
141                                                const Expr *E, bool Negated,
142                                                const char *Constant) {
143   E = E->IgnoreImpCasts();
144   const std::string ExprText =
145       (isa<BinaryOperator>(E) ? ("(" + getText(Context, *E) + ")")
146                               : getText(Context, *E))
147           .str();
148   return ExprText + " " + (Negated ? "!=" : "==") + " " + Constant;
149 }
150 
compareExpressionToNullPtr(const ASTContext & Context,const Expr * E,bool Negated)151 static std::string compareExpressionToNullPtr(const ASTContext &Context,
152                                               const Expr *E, bool Negated) {
153   const char *NullPtr = Context.getLangOpts().CPlusPlus11 ? "nullptr" : "NULL";
154   return compareExpressionToConstant(Context, E, Negated, NullPtr);
155 }
156 
compareExpressionToZero(const ASTContext & Context,const Expr * E,bool Negated)157 static std::string compareExpressionToZero(const ASTContext &Context,
158                                            const Expr *E, bool Negated) {
159   return compareExpressionToConstant(Context, E, Negated, "0");
160 }
161 
replacementExpression(const ASTContext & Context,bool Negated,const Expr * E)162 static std::string replacementExpression(const ASTContext &Context,
163                                          bool Negated, const Expr *E) {
164   E = E->IgnoreParenBaseCasts();
165   if (const auto *EC = dyn_cast<ExprWithCleanups>(E))
166     E = EC->getSubExpr();
167 
168   const bool NeedsStaticCast = needsStaticCast(E);
169   if (Negated) {
170     if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
171       if (UnOp->getOpcode() == UO_LNot) {
172         if (needsNullPtrComparison(UnOp->getSubExpr()))
173           return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), true);
174 
175         if (needsZeroComparison(UnOp->getSubExpr()))
176           return compareExpressionToZero(Context, UnOp->getSubExpr(), true);
177 
178         return replacementExpression(Context, false, UnOp->getSubExpr());
179       }
180     }
181 
182     if (needsNullPtrComparison(E))
183       return compareExpressionToNullPtr(Context, E, false);
184 
185     if (needsZeroComparison(E))
186       return compareExpressionToZero(Context, E, false);
187 
188     StringRef NegatedOperator;
189     const Expr *LHS = nullptr;
190     const Expr *RHS = nullptr;
191     if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
192       NegatedOperator = negatedOperator(BinOp);
193       LHS = BinOp->getLHS();
194       RHS = BinOp->getRHS();
195     } else if (const auto *OpExpr = dyn_cast<CXXOperatorCallExpr>(E)) {
196       if (OpExpr->getNumArgs() == 2) {
197         NegatedOperator = negatedOperator(OpExpr);
198         LHS = OpExpr->getArg(0);
199         RHS = OpExpr->getArg(1);
200       }
201     }
202     if (!NegatedOperator.empty() && LHS && RHS)
203       return (asBool((getText(Context, *LHS) + " " + NegatedOperator + " " +
204                       getText(Context, *RHS))
205                          .str(),
206                      NeedsStaticCast));
207 
208     StringRef Text = getText(Context, *E);
209     if (!NeedsStaticCast && needsParensAfterUnaryNegation(E))
210       return ("!(" + Text + ")").str();
211 
212     if (needsNullPtrComparison(E))
213       return compareExpressionToNullPtr(Context, E, false);
214 
215     if (needsZeroComparison(E))
216       return compareExpressionToZero(Context, E, false);
217 
218     return ("!" + asBool(Text, NeedsStaticCast));
219   }
220 
221   if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
222     if (UnOp->getOpcode() == UO_LNot) {
223       if (needsNullPtrComparison(UnOp->getSubExpr()))
224         return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), false);
225 
226       if (needsZeroComparison(UnOp->getSubExpr()))
227         return compareExpressionToZero(Context, UnOp->getSubExpr(), false);
228     }
229   }
230 
231   if (needsNullPtrComparison(E))
232     return compareExpressionToNullPtr(Context, E, true);
233 
234   if (needsZeroComparison(E))
235     return compareExpressionToZero(Context, E, true);
236 
237   return asBool(getText(Context, *E), NeedsStaticCast);
238 }
239 
containsDiscardedTokens(const ASTContext & Context,CharSourceRange CharRange)240 static bool containsDiscardedTokens(const ASTContext &Context,
241                                     CharSourceRange CharRange) {
242   std::string ReplacementText =
243       Lexer::getSourceText(CharRange, Context.getSourceManager(),
244                            Context.getLangOpts())
245           .str();
246   Lexer Lex(CharRange.getBegin(), Context.getLangOpts(), ReplacementText.data(),
247             ReplacementText.data(),
248             ReplacementText.data() + ReplacementText.size());
249   Lex.SetCommentRetentionState(true);
250 
251   Token Tok;
252   while (!Lex.LexFromRawLexer(Tok)) {
253     if (Tok.is(tok::TokenKind::comment) || Tok.is(tok::TokenKind::hash))
254       return true;
255   }
256 
257   return false;
258 }
259 
260 class SimplifyBooleanExprCheck::Visitor : public RecursiveASTVisitor<Visitor> {
261   using Base = RecursiveASTVisitor<Visitor>;
262 
263 public:
Visitor(SimplifyBooleanExprCheck * Check,ASTContext & Context)264   Visitor(SimplifyBooleanExprCheck *Check, ASTContext &Context)
265       : Check(Check), Context(Context) {}
266 
traverse()267   bool traverse() { return TraverseAST(Context); }
268 
shouldIgnore(Stmt * S)269   static bool shouldIgnore(Stmt *S) {
270     switch (S->getStmtClass()) {
271     case Stmt::ImplicitCastExprClass:
272     case Stmt::MaterializeTemporaryExprClass:
273     case Stmt::CXXBindTemporaryExprClass:
274       return true;
275     default:
276       return false;
277     }
278   }
279 
dataTraverseStmtPre(Stmt * S)280   bool dataTraverseStmtPre(Stmt *S) {
281     if (S && !shouldIgnore(S))
282       StmtStack.push_back(S);
283     return true;
284   }
285 
dataTraverseStmtPost(Stmt * S)286   bool dataTraverseStmtPost(Stmt *S) {
287     if (S && !shouldIgnore(S)) {
288       assert(StmtStack.back() == S);
289       StmtStack.pop_back();
290     }
291     return true;
292   }
293 
VisitBinaryOperator(const BinaryOperator * Op) const294   bool VisitBinaryOperator(const BinaryOperator *Op) const {
295     Check->reportBinOp(Context, Op);
296     return true;
297   }
298 
299   // Extracts a bool if an expression is (true|false|!true|!false);
getAsBoolLiteral(const Expr * E,bool FilterMacro)300   static Optional<bool> getAsBoolLiteral(const Expr *E, bool FilterMacro) {
301     if (const auto *Bool = dyn_cast<CXXBoolLiteralExpr>(E)) {
302       if (FilterMacro && Bool->getBeginLoc().isMacroID())
303         return llvm::None;
304       return Bool->getValue();
305     }
306     if (const auto *UnaryOp = dyn_cast<UnaryOperator>(E)) {
307       if (FilterMacro && UnaryOp->getBeginLoc().isMacroID())
308         return None;
309       if (UnaryOp->getOpcode() == UO_LNot)
310         if (Optional<bool> Res = getAsBoolLiteral(
311                 UnaryOp->getSubExpr()->IgnoreImplicit(), FilterMacro))
312           return !*Res;
313     }
314     return llvm::None;
315   }
316 
317   template <typename Node> struct NodeAndBool {
318     const Node *Item = nullptr;
319     bool Bool = false;
320 
operator boolclang::tidy::readability::SimplifyBooleanExprCheck::Visitor::NodeAndBool321     operator bool() const { return Item != nullptr; }
322   };
323 
324   using ExprAndBool = NodeAndBool<Expr>;
325   using DeclAndBool = NodeAndBool<Decl>;
326 
327   /// Detect's return (true|false|!true|!false);
parseReturnLiteralBool(const Stmt * S)328   static ExprAndBool parseReturnLiteralBool(const Stmt *S) {
329     const auto *RS = dyn_cast<ReturnStmt>(S);
330     if (!RS || !RS->getRetValue())
331       return {};
332     if (Optional<bool> Ret =
333             getAsBoolLiteral(RS->getRetValue()->IgnoreImplicit(), false)) {
334       return {RS->getRetValue(), *Ret};
335     }
336     return {};
337   }
338 
339   /// If \p S is not a \c CompoundStmt, applies F on \p S, otherwise if there is
340   /// only 1 statement in the \c CompoundStmt, applies F on that single
341   /// statement.
342   template <typename Functor>
checkSingleStatement(Stmt * S,Functor F)343   static auto checkSingleStatement(Stmt *S, Functor F) -> decltype(F(S)) {
344     if (auto *CS = dyn_cast<CompoundStmt>(S)) {
345       if (CS->size() == 1)
346         return F(CS->body_front());
347       return {};
348     }
349     return F(S);
350   }
351 
parent() const352   Stmt *parent() const {
353     return StmtStack.size() < 2 ? nullptr : StmtStack[StmtStack.size() - 2];
354   }
355 
VisitIfStmt(IfStmt * If)356   bool VisitIfStmt(IfStmt *If) {
357     // Skip any if's that have a condition var or an init statement.
358     if (If->hasInitStorage() || If->hasVarStorage())
359       return true;
360     /*
361      * if (true) ThenStmt(); -> ThenStmt();
362      * if (false) ThenStmt(); -> <Empty>;
363      * if (false) ThenStmt(); else ElseStmt() -> ElseStmt();
364      */
365     Expr *Cond = If->getCond()->IgnoreImplicit();
366     if (Optional<bool> Bool = getAsBoolLiteral(Cond, true)) {
367       if (*Bool)
368         Check->replaceWithThenStatement(Context, If, Cond);
369       else
370         Check->replaceWithElseStatement(Context, If, Cond);
371     }
372 
373     if (If->getElse()) {
374       /*
375        * if (Cond) return true; else return false; -> return Cond;
376        * if (Cond) return false; else return true; -> return !Cond;
377        */
378       if (ExprAndBool ThenReturnBool =
379               checkSingleStatement(If->getThen(), parseReturnLiteralBool)) {
380         ExprAndBool ElseReturnBool =
381             checkSingleStatement(If->getElse(), parseReturnLiteralBool);
382         if (ElseReturnBool && ThenReturnBool.Bool != ElseReturnBool.Bool) {
383           if (Check->ChainedConditionalReturn ||
384               !isa_and_nonnull<IfStmt>(parent())) {
385             Check->replaceWithReturnCondition(Context, If, ThenReturnBool.Item,
386                                               ElseReturnBool.Bool);
387           }
388         }
389       } else {
390         /*
391          * if (Cond) A = true; else A = false; -> A = Cond;
392          * if (Cond) A = false; else A = true; -> A = !Cond;
393          */
394         Expr *Var = nullptr;
395         SourceLocation Loc;
396         auto VarBoolAssignmentMatcher = [&Var,
397                                          &Loc](const Stmt *S) -> DeclAndBool {
398           const auto *BO = dyn_cast<BinaryOperator>(S);
399           if (!BO || BO->getOpcode() != BO_Assign)
400             return {};
401           Optional<bool> RightasBool =
402               getAsBoolLiteral(BO->getRHS()->IgnoreImplicit(), false);
403           if (!RightasBool)
404             return {};
405           Expr *IgnImp = BO->getLHS()->IgnoreImplicit();
406           if (!Var) {
407             // We only need to track these for the Then branch.
408             Loc = BO->getRHS()->getBeginLoc();
409             Var = IgnImp;
410           }
411           if (auto *DRE = dyn_cast<DeclRefExpr>(IgnImp))
412             return {DRE->getDecl(), *RightasBool};
413           if (auto *ME = dyn_cast<MemberExpr>(IgnImp))
414             return {ME->getMemberDecl(), *RightasBool};
415           return {};
416         };
417         if (DeclAndBool ThenAssignment =
418                 checkSingleStatement(If->getThen(), VarBoolAssignmentMatcher)) {
419           DeclAndBool ElseAssignment =
420               checkSingleStatement(If->getElse(), VarBoolAssignmentMatcher);
421           if (ElseAssignment.Item == ThenAssignment.Item &&
422               ElseAssignment.Bool != ThenAssignment.Bool) {
423             if (Check->ChainedConditionalAssignment ||
424                 !isa_and_nonnull<IfStmt>(parent())) {
425               Check->replaceWithAssignment(Context, If, Var, Loc,
426                                            ElseAssignment.Bool);
427             }
428           }
429         }
430       }
431     }
432     return true;
433   }
434 
VisitConditionalOperator(ConditionalOperator * Cond)435   bool VisitConditionalOperator(ConditionalOperator *Cond) {
436     /*
437      * Condition ? true : false; -> Condition
438      * Condition ? false : true; -> !Condition;
439      */
440     if (Optional<bool> Then =
441             getAsBoolLiteral(Cond->getTrueExpr()->IgnoreImplicit(), false)) {
442       if (Optional<bool> Else =
443               getAsBoolLiteral(Cond->getFalseExpr()->IgnoreImplicit(), false)) {
444         if (*Then != *Else)
445           Check->replaceWithCondition(Context, Cond, *Else);
446       }
447     }
448     return true;
449   }
450 
VisitCompoundStmt(CompoundStmt * CS)451   bool VisitCompoundStmt(CompoundStmt *CS) {
452     if (CS->size() < 2)
453       return true;
454     bool CurIf = false, PrevIf = false;
455     for (auto First = CS->body_begin(), Second = std::next(First),
456               End = CS->body_end();
457          Second != End; ++Second, ++First) {
458       PrevIf = CurIf;
459       CurIf = isa<IfStmt>(*First);
460       ExprAndBool TrailingReturnBool = parseReturnLiteralBool(*Second);
461       if (!TrailingReturnBool)
462         continue;
463 
464       if (CurIf) {
465         /*
466          * if (Cond) return true; return false; -> return Cond;
467          * if (Cond) return false; return true; -> return !Cond;
468          */
469         auto *If = cast<IfStmt>(*First);
470         if (!If->hasInitStorage() && !If->hasVarStorage()) {
471           ExprAndBool ThenReturnBool =
472               checkSingleStatement(If->getThen(), parseReturnLiteralBool);
473           if (ThenReturnBool &&
474               ThenReturnBool.Bool != TrailingReturnBool.Bool) {
475             if ((Check->ChainedConditionalReturn || !PrevIf) &&
476                 If->getElse() == nullptr) {
477               Check->replaceCompoundReturnWithCondition(
478                   Context, cast<ReturnStmt>(*Second), TrailingReturnBool.Bool,
479                   If, ThenReturnBool.Item);
480             }
481           }
482         }
483       } else if (isa<LabelStmt, CaseStmt, DefaultStmt>(*First)) {
484         /*
485          * (case X|label_X|default): if (Cond) return BoolLiteral;
486          *                           return !BoolLiteral
487          */
488         Stmt *SubStmt =
489             isa<LabelStmt>(*First)  ? cast<LabelStmt>(*First)->getSubStmt()
490             : isa<CaseStmt>(*First) ? cast<CaseStmt>(*First)->getSubStmt()
491                                     : cast<DefaultStmt>(*First)->getSubStmt();
492         auto *SubIf = dyn_cast<IfStmt>(SubStmt);
493         if (SubIf && !SubIf->getElse() && !SubIf->hasInitStorage() &&
494             !SubIf->hasVarStorage()) {
495           ExprAndBool ThenReturnBool =
496               checkSingleStatement(SubIf->getThen(), parseReturnLiteralBool);
497           if (ThenReturnBool &&
498               ThenReturnBool.Bool != TrailingReturnBool.Bool) {
499             Check->replaceCompoundReturnWithCondition(
500                 Context, cast<ReturnStmt>(*Second), TrailingReturnBool.Bool,
501                 SubIf, ThenReturnBool.Item);
502           }
503         }
504       }
505     }
506     return true;
507   }
508 
isUnaryLNot(const Expr * E)509   static bool isUnaryLNot(const Expr *E) {
510     return isa<UnaryOperator>(E) &&
511            cast<UnaryOperator>(E)->getOpcode() == UO_LNot;
512   }
513 
514   template <typename Functor>
checkEitherSide(const BinaryOperator * BO,Functor Func)515   static bool checkEitherSide(const BinaryOperator *BO, Functor Func) {
516     return Func(BO->getLHS()) || Func(BO->getRHS());
517   }
518 
nestedDemorgan(const Expr * E,unsigned NestingLevel)519   static bool nestedDemorgan(const Expr *E, unsigned NestingLevel) {
520     const auto *BO = dyn_cast<BinaryOperator>(E->IgnoreUnlessSpelledInSource());
521     if (!BO)
522       return false;
523     if (!BO->getType()->isBooleanType())
524       return false;
525     switch (BO->getOpcode()) {
526     case BO_LT:
527     case BO_GT:
528     case BO_LE:
529     case BO_GE:
530     case BO_EQ:
531     case BO_NE:
532       return true;
533     case BO_LAnd:
534     case BO_LOr:
535       if (checkEitherSide(BO, isUnaryLNot))
536         return true;
537       if (NestingLevel) {
538         if (checkEitherSide(BO, [NestingLevel](const Expr *E) {
539               return nestedDemorgan(E, NestingLevel - 1);
540             }))
541           return true;
542       }
543       return false;
544     default:
545       return false;
546     }
547   }
548 
TraverseUnaryOperator(UnaryOperator * Op)549   bool TraverseUnaryOperator(UnaryOperator *Op) {
550     if (!Check->SimplifyDeMorgan || Op->getOpcode() != UO_LNot)
551       return Base::TraverseUnaryOperator(Op);
552     Expr *SubImp = Op->getSubExpr()->IgnoreImplicit();
553     auto *Parens = dyn_cast<ParenExpr>(SubImp);
554     auto *BinaryOp =
555         Parens
556             ? dyn_cast<BinaryOperator>(Parens->getSubExpr()->IgnoreImplicit())
557             : dyn_cast<BinaryOperator>(SubImp);
558     if (!BinaryOp || !BinaryOp->isLogicalOp() ||
559         !BinaryOp->getType()->isBooleanType())
560       return Base::TraverseUnaryOperator(Op);
561     if (Check->SimplifyDeMorganRelaxed ||
562         checkEitherSide(BinaryOp, isUnaryLNot) ||
563         checkEitherSide(BinaryOp,
564                         [](const Expr *E) { return nestedDemorgan(E, 1); })) {
565       if (Check->reportDeMorgan(Context, Op, BinaryOp, !IsProcessing, parent(),
566                                 Parens) &&
567           !Check->areDiagsSelfContained()) {
568         llvm::SaveAndRestore<bool> RAII(IsProcessing, true);
569         return Base::TraverseUnaryOperator(Op);
570       }
571     }
572     return Base::TraverseUnaryOperator(Op);
573   }
574 
575 private:
576   bool IsProcessing = false;
577   SimplifyBooleanExprCheck *Check;
578   SmallVector<Stmt *, 32> StmtStack;
579   ASTContext &Context;
580 };
581 
SimplifyBooleanExprCheck(StringRef Name,ClangTidyContext * Context)582 SimplifyBooleanExprCheck::SimplifyBooleanExprCheck(StringRef Name,
583                                                    ClangTidyContext *Context)
584     : ClangTidyCheck(Name, Context),
585       ChainedConditionalReturn(Options.get("ChainedConditionalReturn", false)),
586       ChainedConditionalAssignment(
587           Options.get("ChainedConditionalAssignment", false)),
588       SimplifyDeMorgan(Options.get("SimplifyDeMorgan", true)),
589       SimplifyDeMorganRelaxed(Options.get("SimplifyDeMorganRelaxed", false)) {
590   if (SimplifyDeMorganRelaxed && !SimplifyDeMorgan)
591     configurationDiag("%0: 'SimplifyDeMorganRelaxed' cannot be enabled "
592                       "without 'SimplifyDeMorgan' enabled")
593         << Name;
594 }
595 
containsBoolLiteral(const Expr * E)596 static bool containsBoolLiteral(const Expr *E) {
597   if (!E)
598     return false;
599   E = E->IgnoreParenImpCasts();
600   if (isa<CXXBoolLiteralExpr>(E))
601     return true;
602   if (const auto *BinOp = dyn_cast<BinaryOperator>(E))
603     return containsBoolLiteral(BinOp->getLHS()) ||
604            containsBoolLiteral(BinOp->getRHS());
605   if (const auto *UnaryOp = dyn_cast<UnaryOperator>(E))
606     return containsBoolLiteral(UnaryOp->getSubExpr());
607   return false;
608 }
609 
reportBinOp(const ASTContext & Context,const BinaryOperator * Op)610 void SimplifyBooleanExprCheck::reportBinOp(const ASTContext &Context,
611                                            const BinaryOperator *Op) {
612   const auto *LHS = Op->getLHS()->IgnoreParenImpCasts();
613   const auto *RHS = Op->getRHS()->IgnoreParenImpCasts();
614 
615   const CXXBoolLiteralExpr *Bool;
616   const Expr *Other;
617   if ((Bool = dyn_cast<CXXBoolLiteralExpr>(LHS)) != nullptr)
618     Other = RHS;
619   else if ((Bool = dyn_cast<CXXBoolLiteralExpr>(RHS)) != nullptr)
620     Other = LHS;
621   else
622     return;
623 
624   if (Bool->getBeginLoc().isMacroID())
625     return;
626 
627   // FIXME: why do we need this?
628   if (!isa<CXXBoolLiteralExpr>(Other) && containsBoolLiteral(Other))
629     return;
630 
631   bool BoolValue = Bool->getValue();
632 
633   auto ReplaceWithExpression = [this, &Context, LHS, RHS,
634                                 Bool](const Expr *ReplaceWith, bool Negated) {
635     std::string Replacement =
636         replacementExpression(Context, Negated, ReplaceWith);
637     SourceRange Range(LHS->getBeginLoc(), RHS->getEndLoc());
638     issueDiag(Context, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range,
639               Replacement);
640   };
641 
642   switch (Op->getOpcode()) {
643   case BO_LAnd:
644     if (BoolValue)
645       // expr && true -> expr
646       ReplaceWithExpression(Other, /*Negated=*/false);
647     else
648       // expr && false -> false
649       ReplaceWithExpression(Bool, /*Negated=*/false);
650     break;
651   case BO_LOr:
652     if (BoolValue)
653       // expr || true -> true
654       ReplaceWithExpression(Bool, /*Negated=*/false);
655     else
656       // expr || false -> expr
657       ReplaceWithExpression(Other, /*Negated=*/false);
658     break;
659   case BO_EQ:
660     // expr == true -> expr, expr == false -> !expr
661     ReplaceWithExpression(Other, /*Negated=*/!BoolValue);
662     break;
663   case BO_NE:
664     // expr != true -> !expr, expr != false -> expr
665     ReplaceWithExpression(Other, /*Negated=*/BoolValue);
666     break;
667   default:
668     break;
669   }
670 }
671 
storeOptions(ClangTidyOptions::OptionMap & Opts)672 void SimplifyBooleanExprCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
673   Options.store(Opts, "ChainedConditionalReturn", ChainedConditionalReturn);
674   Options.store(Opts, "ChainedConditionalAssignment",
675                 ChainedConditionalAssignment);
676   Options.store(Opts, "SimplifyDeMorgan", SimplifyDeMorgan);
677   Options.store(Opts, "SimplifyDeMorganRelaxed", SimplifyDeMorganRelaxed);
678 }
679 
registerMatchers(MatchFinder * Finder)680 void SimplifyBooleanExprCheck::registerMatchers(MatchFinder *Finder) {
681   Finder->addMatcher(translationUnitDecl(), this);
682 }
683 
check(const MatchFinder::MatchResult & Result)684 void SimplifyBooleanExprCheck::check(const MatchFinder::MatchResult &Result) {
685   Visitor(this, *Result.Context).traverse();
686 }
687 
issueDiag(const ASTContext & Context,SourceLocation Loc,StringRef Description,SourceRange ReplacementRange,StringRef Replacement)688 void SimplifyBooleanExprCheck::issueDiag(const ASTContext &Context,
689                                          SourceLocation Loc,
690                                          StringRef Description,
691                                          SourceRange ReplacementRange,
692                                          StringRef Replacement) {
693   CharSourceRange CharRange =
694       Lexer::makeFileCharRange(CharSourceRange::getTokenRange(ReplacementRange),
695                                Context.getSourceManager(), getLangOpts());
696 
697   DiagnosticBuilder Diag = diag(Loc, Description);
698   if (!containsDiscardedTokens(Context, CharRange))
699     Diag << FixItHint::CreateReplacement(CharRange, Replacement);
700 }
701 
replaceWithThenStatement(const ASTContext & Context,const IfStmt * IfStatement,const Expr * BoolLiteral)702 void SimplifyBooleanExprCheck::replaceWithThenStatement(
703     const ASTContext &Context, const IfStmt *IfStatement,
704     const Expr *BoolLiteral) {
705   issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
706             IfStatement->getSourceRange(),
707             getText(Context, *IfStatement->getThen()));
708 }
709 
replaceWithElseStatement(const ASTContext & Context,const IfStmt * IfStatement,const Expr * BoolLiteral)710 void SimplifyBooleanExprCheck::replaceWithElseStatement(
711     const ASTContext &Context, const IfStmt *IfStatement,
712     const Expr *BoolLiteral) {
713   const Stmt *ElseStatement = IfStatement->getElse();
714   issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
715             IfStatement->getSourceRange(),
716             ElseStatement ? getText(Context, *ElseStatement) : "");
717 }
718 
replaceWithCondition(const ASTContext & Context,const ConditionalOperator * Ternary,bool Negated)719 void SimplifyBooleanExprCheck::replaceWithCondition(
720     const ASTContext &Context, const ConditionalOperator *Ternary,
721     bool Negated) {
722   std::string Replacement =
723       replacementExpression(Context, Negated, Ternary->getCond());
724   issueDiag(Context, Ternary->getTrueExpr()->getBeginLoc(),
725             "redundant boolean literal in ternary expression result",
726             Ternary->getSourceRange(), Replacement);
727 }
728 
replaceWithReturnCondition(const ASTContext & Context,const IfStmt * If,const Expr * BoolLiteral,bool Negated)729 void SimplifyBooleanExprCheck::replaceWithReturnCondition(
730     const ASTContext &Context, const IfStmt *If, const Expr *BoolLiteral,
731     bool Negated) {
732   StringRef Terminator = isa<CompoundStmt>(If->getElse()) ? ";" : "";
733   std::string Condition =
734       replacementExpression(Context, Negated, If->getCond());
735   std::string Replacement = ("return " + Condition + Terminator).str();
736   SourceLocation Start = BoolLiteral->getBeginLoc();
737   issueDiag(Context, Start, SimplifyConditionalReturnDiagnostic,
738             If->getSourceRange(), Replacement);
739 }
740 
replaceCompoundReturnWithCondition(const ASTContext & Context,const ReturnStmt * Ret,bool Negated,const IfStmt * If,const Expr * ThenReturn)741 void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition(
742     const ASTContext &Context, const ReturnStmt *Ret, bool Negated,
743     const IfStmt *If, const Expr *ThenReturn) {
744   const std::string Replacement =
745       "return " + replacementExpression(Context, Negated, If->getCond());
746   issueDiag(Context, ThenReturn->getBeginLoc(),
747             SimplifyConditionalReturnDiagnostic,
748             SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement);
749 }
750 
replaceWithAssignment(const ASTContext & Context,const IfStmt * IfAssign,const Expr * Var,SourceLocation Loc,bool Negated)751 void SimplifyBooleanExprCheck::replaceWithAssignment(const ASTContext &Context,
752                                                      const IfStmt *IfAssign,
753                                                      const Expr *Var,
754                                                      SourceLocation Loc,
755                                                      bool Negated) {
756   SourceRange Range = IfAssign->getSourceRange();
757   StringRef VariableName = getText(Context, *Var);
758   StringRef Terminator = isa<CompoundStmt>(IfAssign->getElse()) ? ";" : "";
759   std::string Condition =
760       replacementExpression(Context, Negated, IfAssign->getCond());
761   std::string Replacement =
762       (VariableName + " = " + Condition + Terminator).str();
763   issueDiag(Context, Loc, "redundant boolean literal in conditional assignment",
764             Range, Replacement);
765 }
766 
767 /// Swaps a \c BinaryOperator opcode from `&&` to `||` or vice-versa.
flipDemorganOperator(llvm::SmallVectorImpl<FixItHint> & Output,const BinaryOperator * BO)768 static bool flipDemorganOperator(llvm::SmallVectorImpl<FixItHint> &Output,
769                                  const BinaryOperator *BO) {
770   assert(BO->isLogicalOp());
771   if (BO->getOperatorLoc().isMacroID())
772     return true;
773   Output.push_back(FixItHint::CreateReplacement(
774       BO->getOperatorLoc(), BO->getOpcode() == BO_LAnd ? "||" : "&&"));
775   return false;
776 }
777 
getDemorganFlippedOperator(BinaryOperatorKind BO)778 static BinaryOperatorKind getDemorganFlippedOperator(BinaryOperatorKind BO) {
779   assert(BinaryOperator::isLogicalOp(BO));
780   return BO == BO_LAnd ? BO_LOr : BO_LAnd;
781 }
782 
783 static bool flipDemorganSide(SmallVectorImpl<FixItHint> &Fixes,
784                              const ASTContext &Ctx, const Expr *E,
785                              Optional<BinaryOperatorKind> OuterBO);
786 
787 /// Inverts \p BinOp, Removing \p Parens if they exist and are safe to remove.
788 /// returns \c true if there is any issue building the Fixes, \c false
789 /// otherwise.
flipDemorganBinaryOperator(SmallVectorImpl<FixItHint> & Fixes,const ASTContext & Ctx,const BinaryOperator * BinOp,Optional<BinaryOperatorKind> OuterBO,const ParenExpr * Parens=nullptr)790 static bool flipDemorganBinaryOperator(SmallVectorImpl<FixItHint> &Fixes,
791                                        const ASTContext &Ctx,
792                                        const BinaryOperator *BinOp,
793                                        Optional<BinaryOperatorKind> OuterBO,
794                                        const ParenExpr *Parens = nullptr) {
795   switch (BinOp->getOpcode()) {
796   case BO_LAnd:
797   case BO_LOr: {
798     // if we have 'a && b' or 'a || b', use demorgan to flip it to '!a || !b'
799     // or '!a && !b'.
800     if (flipDemorganOperator(Fixes, BinOp))
801       return true;
802     auto NewOp = getDemorganFlippedOperator(BinOp->getOpcode());
803     if (OuterBO) {
804       // The inner parens are technically needed in a fix for
805       // `!(!A1 && !(A2 || A3)) -> (A1 || (A2 && A3))`,
806       // however this would trip the LogicalOpParentheses warning.
807       // FIXME: Make this user configurable or detect if that warning is
808       // enabled.
809       constexpr bool LogicalOpParentheses = true;
810       if (((*OuterBO == NewOp) || (!LogicalOpParentheses &&
811                                    (*OuterBO == BO_LOr && NewOp == BO_LAnd))) &&
812           Parens) {
813         if (!Parens->getLParen().isMacroID() &&
814             !Parens->getRParen().isMacroID()) {
815           Fixes.push_back(FixItHint::CreateRemoval(Parens->getLParen()));
816           Fixes.push_back(FixItHint::CreateRemoval(Parens->getRParen()));
817         }
818       }
819       if (*OuterBO == BO_LAnd && NewOp == BO_LOr && !Parens) {
820         Fixes.push_back(FixItHint::CreateInsertion(BinOp->getBeginLoc(), "("));
821         Fixes.push_back(FixItHint::CreateInsertion(
822             Lexer::getLocForEndOfToken(BinOp->getEndLoc(), 0,
823                                        Ctx.getSourceManager(),
824                                        Ctx.getLangOpts()),
825             ")"));
826       }
827     }
828     if (flipDemorganSide(Fixes, Ctx, BinOp->getLHS(), NewOp) ||
829         flipDemorganSide(Fixes, Ctx, BinOp->getRHS(), NewOp))
830       return true;
831     return false;
832   };
833   case BO_LT:
834   case BO_GT:
835   case BO_LE:
836   case BO_GE:
837   case BO_EQ:
838   case BO_NE:
839     // For comparison operators, just negate the comparison.
840     if (BinOp->getOperatorLoc().isMacroID())
841       return true;
842     Fixes.push_back(FixItHint::CreateReplacement(
843         BinOp->getOperatorLoc(),
844         BinaryOperator::getOpcodeStr(
845             BinaryOperator::negateComparisonOp(BinOp->getOpcode()))));
846     return false;
847   default:
848     // for any other binary operator, just use logical not and wrap in
849     // parens.
850     if (Parens) {
851       if (Parens->getBeginLoc().isMacroID())
852         return true;
853       Fixes.push_back(FixItHint::CreateInsertion(Parens->getBeginLoc(), "!"));
854     } else {
855       if (BinOp->getBeginLoc().isMacroID() || BinOp->getEndLoc().isMacroID())
856         return true;
857       Fixes.append({FixItHint::CreateInsertion(BinOp->getBeginLoc(), "!("),
858                     FixItHint::CreateInsertion(
859                         Lexer::getLocForEndOfToken(BinOp->getEndLoc(), 0,
860                                                    Ctx.getSourceManager(),
861                                                    Ctx.getLangOpts()),
862                         ")")});
863     }
864     break;
865   }
866   return false;
867 }
868 
flipDemorganSide(SmallVectorImpl<FixItHint> & Fixes,const ASTContext & Ctx,const Expr * E,Optional<BinaryOperatorKind> OuterBO)869 static bool flipDemorganSide(SmallVectorImpl<FixItHint> &Fixes,
870                              const ASTContext &Ctx, const Expr *E,
871                              Optional<BinaryOperatorKind> OuterBO) {
872   if (isa<UnaryOperator>(E) && cast<UnaryOperator>(E)->getOpcode() == UO_LNot) {
873     //  if we have a not operator, '!a', just remove the '!'.
874     if (cast<UnaryOperator>(E)->getOperatorLoc().isMacroID())
875       return true;
876     Fixes.push_back(
877         FixItHint::CreateRemoval(cast<UnaryOperator>(E)->getOperatorLoc()));
878     return false;
879   }
880   if (const auto *BinOp = dyn_cast<BinaryOperator>(E)) {
881     return flipDemorganBinaryOperator(Fixes, Ctx, BinOp, OuterBO);
882   }
883   if (const auto *Paren = dyn_cast<ParenExpr>(E)) {
884     if (const auto *BinOp = dyn_cast<BinaryOperator>(Paren->getSubExpr())) {
885       return flipDemorganBinaryOperator(Fixes, Ctx, BinOp, OuterBO, Paren);
886     }
887   }
888   // Fallback case just insert a logical not operator.
889   if (E->getBeginLoc().isMacroID())
890     return true;
891   Fixes.push_back(FixItHint::CreateInsertion(E->getBeginLoc(), "!"));
892   return false;
893 }
894 
shouldRemoveParens(const Stmt * Parent,BinaryOperatorKind NewOuterBinary,const ParenExpr * Parens)895 static bool shouldRemoveParens(const Stmt *Parent,
896                                BinaryOperatorKind NewOuterBinary,
897                                const ParenExpr *Parens) {
898   if (!Parens)
899     return false;
900   if (!Parent)
901     return true;
902   switch (Parent->getStmtClass()) {
903   case Stmt::BinaryOperatorClass: {
904     const auto *BO = cast<BinaryOperator>(Parent);
905     if (BO->isAssignmentOp())
906       return true;
907     if (BO->isCommaOp())
908       return true;
909     if (BO->getOpcode() == NewOuterBinary)
910       return true;
911     return false;
912   }
913   case Stmt::UnaryOperatorClass:
914   case Stmt::CXXRewrittenBinaryOperatorClass:
915     return false;
916   default:
917     return true;
918   }
919 }
920 
reportDeMorgan(const ASTContext & Context,const UnaryOperator * Outer,const BinaryOperator * Inner,bool TryOfferFix,const Stmt * Parent,const ParenExpr * Parens)921 bool SimplifyBooleanExprCheck::reportDeMorgan(const ASTContext &Context,
922                                               const UnaryOperator *Outer,
923                                               const BinaryOperator *Inner,
924                                               bool TryOfferFix,
925                                               const Stmt *Parent,
926                                               const ParenExpr *Parens) {
927   assert(Outer);
928   assert(Inner);
929   assert(Inner->isLogicalOp());
930 
931   auto Diag =
932       diag(Outer->getBeginLoc(),
933            "boolean expression can be simplified by DeMorgan's theorem");
934   Diag << Outer->getSourceRange();
935   // If we have already fixed this with a previous fix, don't attempt any fixes
936   if (!TryOfferFix)
937     return false;
938   if (Outer->getOperatorLoc().isMacroID())
939     return false;
940   SmallVector<FixItHint> Fixes;
941   auto NewOpcode = getDemorganFlippedOperator(Inner->getOpcode());
942   if (shouldRemoveParens(Parent, NewOpcode, Parens)) {
943     Fixes.push_back(FixItHint::CreateRemoval(
944         SourceRange(Outer->getOperatorLoc(), Parens->getLParen())));
945     Fixes.push_back(FixItHint::CreateRemoval(Parens->getRParen()));
946   } else {
947     Fixes.push_back(FixItHint::CreateRemoval(Outer->getOperatorLoc()));
948   }
949   if (flipDemorganOperator(Fixes, Inner))
950     return false;
951   if (flipDemorganSide(Fixes, Context, Inner->getLHS(), NewOpcode) ||
952       flipDemorganSide(Fixes, Context, Inner->getRHS(), NewOpcode))
953     return false;
954   Diag << Fixes;
955   return true;
956 }
957 } // namespace readability
958 } // namespace tidy
959 } // namespace clang
960