1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 /// 10 /// This file contains functions which are used to decide if a loop worth to be 11 /// unrolled. Moreover, these functions manages the stack of loop which is 12 /// tracked by the ProgramState. 13 /// 14 //===----------------------------------------------------------------------===// 15 16 #include "clang/Analysis/CFGStmtMap.h" 17 #include "clang/ASTMatchers/ASTMatchers.h" 18 #include "clang/ASTMatchers/ASTMatchFinder.h" 19 #include "clang/AST/ParentMap.h" 20 #include "clang/AST/StmtVisitor.h" 21 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" 22 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" 23 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" 24 #include "llvm/ADT/Statistic.h" 25 26 using namespace clang; 27 using namespace ento; 28 using namespace clang::ast_matchers; 29 30 struct LoopState { 31 private: 32 enum Kind { Normal, Unrolled } K; 33 const Stmt *LoopStmt; 34 const LocationContext *LCtx; 35 LoopState(Kind InK, const Stmt *S, const LocationContext *L) 36 : K(InK), LoopStmt(S), LCtx(L) {} 37 38 public: 39 static LoopState getNormal(const Stmt *S, const LocationContext *L) { 40 return LoopState(Normal, S, L); 41 } 42 static LoopState getUnrolled(const Stmt *S, const LocationContext *L) { 43 return LoopState(Unrolled, S, L); 44 } 45 bool isUnrolled() const { return K == Unrolled; } 46 const Stmt *getLoopStmt() const { return LoopStmt; } 47 const LocationContext *getLocationContext() const { return LCtx; } 48 bool operator==(const LoopState &X) const { 49 return K == X.K && LoopStmt == X.LoopStmt; 50 } 51 void Profile(llvm::FoldingSetNodeID &ID) const { 52 ID.AddInteger(K); 53 ID.AddPointer(LoopStmt); 54 ID.AddPointer(LCtx); 55 } 56 }; 57 58 // The tracked stack of loops. The stack indicates that which loops the 59 // simulated element contained by. The loops are marked depending if we decided 60 // to unroll them. 61 // TODO: The loop stack should not need to be in the program state since it is 62 // lexical in nature. Instead, the stack of loops should be tracked in the 63 // LocationContext. 64 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) 65 66 namespace clang { 67 namespace ento { 68 69 static bool isLoopStmt(const Stmt *S) { 70 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S)); 71 } 72 73 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { 74 auto LS = State->get<LoopStack>(); 75 assert(!LS.isEmpty() && "Loop not added to the stack."); 76 assert(LoopStmt == LS.getHead().getLoopStmt() && 77 "Loop is not on top of the stack."); 78 79 State = State->set<LoopStack>(LS.getTail()); 80 return State; 81 } 82 83 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) { 84 return binaryOperator( 85 anyOf(hasOperatorName("<"), hasOperatorName(">"), hasOperatorName("<="), 86 hasOperatorName(">="), hasOperatorName("!=")), 87 hasEitherOperand(ignoringParenImpCasts( 88 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))))), 89 hasEitherOperand(ignoringParenImpCasts(integerLiteral()))); 90 } 91 92 static internal::Matcher<Stmt> 93 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { 94 return anyOf( 95 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), 96 hasUnaryOperand(ignoringParenImpCasts( 97 declRefExpr(to(varDecl(VarNodeMatcher)))))), 98 binaryOperator(anyOf(hasOperatorName("="), hasOperatorName("+="), 99 hasOperatorName("/="), hasOperatorName("*="), 100 hasOperatorName("-=")), 101 hasLHS(ignoringParenImpCasts( 102 declRefExpr(to(varDecl(VarNodeMatcher))))))); 103 } 104 105 static internal::Matcher<Stmt> 106 callByRef(internal::Matcher<Decl> VarNodeMatcher) { 107 return callExpr(forEachArgumentWithParam( 108 declRefExpr(to(varDecl(VarNodeMatcher))), 109 parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); 110 } 111 112 static internal::Matcher<Stmt> 113 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { 114 return declStmt(hasDescendant(varDecl( 115 allOf(hasType(referenceType()), 116 hasInitializer(anyOf( 117 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), 118 declRefExpr(to(varDecl(VarNodeMatcher))))))))); 119 } 120 121 static internal::Matcher<Stmt> 122 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { 123 return unaryOperator( 124 hasOperatorName("&"), 125 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); 126 } 127 128 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { 129 return hasDescendant(stmt( 130 anyOf(gotoStmt(), switchStmt(), returnStmt(), 131 // Escaping and not known mutation of the loop counter is handled 132 // by exclusion of assigning and address-of operators and 133 // pass-by-ref function calls on the loop counter from the body. 134 changeIntBoundNode(equalsBoundNode(NodeName)), 135 callByRef(equalsBoundNode(NodeName)), 136 getAddrTo(equalsBoundNode(NodeName)), 137 assignedToRef(equalsBoundNode(NodeName))))); 138 } 139 140 static internal::Matcher<Stmt> forLoopMatcher() { 141 return forStmt( 142 hasCondition(simpleCondition("initVarName")), 143 // Initialization should match the form: 'int i = 6' or 'i = 42'. 144 hasLoopInit( 145 anyOf(declStmt(hasSingleDecl( 146 varDecl(allOf(hasInitializer(integerLiteral()), 147 equalsBoundNode("initVarName"))))), 148 binaryOperator(hasLHS(declRefExpr(to(varDecl( 149 equalsBoundNode("initVarName"))))), 150 hasRHS(integerLiteral())))), 151 // Incrementation should be a simple increment or decrement 152 // operator call. 153 hasIncrement(unaryOperator( 154 anyOf(hasOperatorName("++"), hasOperatorName("--")), 155 hasUnaryOperand(declRefExpr( 156 to(varDecl(allOf(equalsBoundNode("initVarName"), 157 hasType(isInteger())))))))), 158 unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop"); 159 } 160 161 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) { 162 // Global variables assumed as escaped variables. 163 if (VD->hasGlobalStorage()) 164 return true; 165 166 while (!N->pred_empty()) { 167 const Stmt *S = PathDiagnosticLocation::getStmt(N); 168 if (!S) { 169 N = N->getFirstPred(); 170 continue; 171 } 172 173 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { 174 for (const Decl *D : DS->decls()) { 175 // Once we reach the declaration of the VD we can return. 176 if (D->getCanonicalDecl() == VD) 177 return false; 178 } 179 } 180 // Check the usage of the pass-by-ref function calls and adress-of operator 181 // on VD and reference initialized by VD. 182 ASTContext &ASTCtx = 183 N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); 184 auto Match = 185 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), 186 assignedToRef(equalsNode(VD)))), 187 *S, ASTCtx); 188 if (!Match.empty()) 189 return true; 190 191 N = N->getFirstPred(); 192 } 193 llvm_unreachable("Reached root without finding the declaration of VD"); 194 } 195 196 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, 197 ExplodedNode *Pred) { 198 199 if (!isLoopStmt(LoopStmt)) 200 return false; 201 202 // TODO: Match the cases where the bound is not a concrete literal but an 203 // integer with known value 204 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); 205 if (Matches.empty()) 206 return false; 207 208 auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName"); 209 210 // Check if the counter of the loop is not escaped before. 211 return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred); 212 } 213 214 // updateLoopStack is called on every basic block, therefore it needs to be fast 215 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, 216 ExplodedNode* Pred) { 217 auto State = Pred->getState(); 218 auto LCtx = Pred->getLocationContext(); 219 220 if (!isLoopStmt(LoopStmt)) 221 return State; 222 223 auto LS = State->get<LoopStack>(); 224 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && 225 LCtx == LS.getHead().getLocationContext()) 226 return State; 227 228 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred)) { 229 State = State->add<LoopStack>(LoopState::getNormal(LoopStmt, LCtx)); 230 return State; 231 } 232 233 State = State->add<LoopStack>(LoopState::getUnrolled(LoopStmt, LCtx)); 234 return State; 235 } 236 237 bool isUnrolledState(ProgramStateRef State) { 238 auto LS = State->get<LoopStack>(); 239 if (LS.isEmpty() || !LS.getHead().isUnrolled()) 240 return false; 241 return true; 242 } 243 } 244 } 245