1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 /// 10 /// This file contains functions which are used to decide if a loop worth to be 11 /// unrolled. Moreover, these functions manages the stack of loop which is 12 /// tracked by the ProgramState. 13 /// 14 //===----------------------------------------------------------------------===// 15 16 #include "clang/ASTMatchers/ASTMatchers.h" 17 #include "clang/ASTMatchers/ASTMatchFinder.h" 18 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" 19 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" 20 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" 21 22 using namespace clang; 23 using namespace ento; 24 using namespace clang::ast_matchers; 25 26 struct LoopState { 27 private: 28 enum Kind { Normal, Unrolled } K; 29 const Stmt *LoopStmt; 30 const LocationContext *LCtx; 31 LoopState(Kind InK, const Stmt *S, const LocationContext *L) 32 : K(InK), LoopStmt(S), LCtx(L) {} 33 34 public: 35 static LoopState getNormal(const Stmt *S, const LocationContext *L) { 36 return LoopState(Normal, S, L); 37 } 38 static LoopState getUnrolled(const Stmt *S, const LocationContext *L) { 39 return LoopState(Unrolled, S, L); 40 } 41 bool isUnrolled() const { return K == Unrolled; } 42 const Stmt *getLoopStmt() const { return LoopStmt; } 43 const LocationContext *getLocationContext() const { return LCtx; } 44 bool operator==(const LoopState &X) const { 45 return K == X.K && LoopStmt == X.LoopStmt; 46 } 47 void Profile(llvm::FoldingSetNodeID &ID) const { 48 ID.AddInteger(K); 49 ID.AddPointer(LoopStmt); 50 ID.AddPointer(LCtx); 51 } 52 }; 53 54 // The tracked stack of loops. The stack indicates that which loops the 55 // simulated element contained by. The loops are marked depending if we decided 56 // to unroll them. 57 // TODO: The loop stack should not need to be in the program state since it is 58 // lexical in nature. Instead, the stack of loops should be tracked in the 59 // LocationContext. 60 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) 61 62 namespace clang { 63 namespace ento { 64 65 static bool isLoopStmt(const Stmt *S) { 66 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S)); 67 } 68 69 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { 70 auto LS = State->get<LoopStack>(); 71 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt) 72 State = State->set<LoopStack>(LS.getTail()); 73 return State; 74 } 75 76 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) { 77 return binaryOperator( 78 anyOf(hasOperatorName("<"), hasOperatorName(">"), hasOperatorName("<="), 79 hasOperatorName(">="), hasOperatorName("!=")), 80 hasEitherOperand(ignoringParenImpCasts( 81 declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))))), 82 hasEitherOperand(ignoringParenImpCasts(integerLiteral()))); 83 } 84 85 static internal::Matcher<Stmt> 86 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { 87 return anyOf( 88 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), 89 hasUnaryOperand(ignoringParenImpCasts( 90 declRefExpr(to(varDecl(VarNodeMatcher)))))), 91 binaryOperator(anyOf(hasOperatorName("="), hasOperatorName("+="), 92 hasOperatorName("/="), hasOperatorName("*="), 93 hasOperatorName("-=")), 94 hasLHS(ignoringParenImpCasts( 95 declRefExpr(to(varDecl(VarNodeMatcher))))))); 96 } 97 98 static internal::Matcher<Stmt> 99 callByRef(internal::Matcher<Decl> VarNodeMatcher) { 100 return callExpr(forEachArgumentWithParam( 101 declRefExpr(to(varDecl(VarNodeMatcher))), 102 parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); 103 } 104 105 static internal::Matcher<Stmt> 106 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { 107 return declStmt(hasDescendant(varDecl( 108 allOf(hasType(referenceType()), 109 hasInitializer(anyOf( 110 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), 111 declRefExpr(to(varDecl(VarNodeMatcher))))))))); 112 } 113 114 static internal::Matcher<Stmt> 115 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { 116 return unaryOperator( 117 hasOperatorName("&"), 118 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); 119 } 120 121 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { 122 return hasDescendant(stmt( 123 anyOf(gotoStmt(), switchStmt(), returnStmt(), 124 // Escaping and not known mutation of the loop counter is handled 125 // by exclusion of assigning and address-of operators and 126 // pass-by-ref function calls on the loop counter from the body. 127 changeIntBoundNode(equalsBoundNode(NodeName)), 128 callByRef(equalsBoundNode(NodeName)), 129 getAddrTo(equalsBoundNode(NodeName)), 130 assignedToRef(equalsBoundNode(NodeName))))); 131 } 132 133 static internal::Matcher<Stmt> forLoopMatcher() { 134 return forStmt( 135 hasCondition(simpleCondition("initVarName")), 136 // Initialization should match the form: 'int i = 6' or 'i = 42'. 137 hasLoopInit( 138 anyOf(declStmt(hasSingleDecl( 139 varDecl(allOf(hasInitializer(integerLiteral()), 140 equalsBoundNode("initVarName"))))), 141 binaryOperator(hasLHS(declRefExpr(to(varDecl( 142 equalsBoundNode("initVarName"))))), 143 hasRHS(integerLiteral())))), 144 // Incrementation should be a simple increment or decrement 145 // operator call. 146 hasIncrement(unaryOperator( 147 anyOf(hasOperatorName("++"), hasOperatorName("--")), 148 hasUnaryOperand(declRefExpr( 149 to(varDecl(allOf(equalsBoundNode("initVarName"), 150 hasType(isInteger())))))))), 151 unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop"); 152 } 153 154 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) { 155 // Global variables assumed as escaped variables. 156 if (VD->hasGlobalStorage()) 157 return true; 158 159 while (!N->pred_empty()) { 160 const Stmt *S = PathDiagnosticLocation::getStmt(N); 161 if (!S) { 162 N = N->getFirstPred(); 163 continue; 164 } 165 166 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { 167 for (const Decl *D : DS->decls()) { 168 // Once we reach the declaration of the VD we can return. 169 if (D->getCanonicalDecl() == VD) 170 return false; 171 } 172 } 173 // Check the usage of the pass-by-ref function calls and adress-of operator 174 // on VD and reference initialized by VD. 175 ASTContext &ASTCtx = 176 N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); 177 auto Match = 178 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), 179 assignedToRef(equalsNode(VD)))), 180 *S, ASTCtx); 181 if (!Match.empty()) 182 return true; 183 184 N = N->getFirstPred(); 185 } 186 llvm_unreachable("Reached root without finding the declaration of VD"); 187 } 188 189 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, 190 ExplodedNode *Pred) { 191 192 if (!isLoopStmt(LoopStmt)) 193 return false; 194 195 // TODO: Match the cases where the bound is not a concrete literal but an 196 // integer with known value 197 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); 198 if (Matches.empty()) 199 return false; 200 201 auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName"); 202 203 // Check if the counter of the loop is not escaped before. 204 return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred); 205 } 206 207 bool madeNewBranch(ExplodedNode* N, const Stmt* LoopStmt) { 208 const Stmt* S = nullptr; 209 while (!N->pred_empty()) 210 { 211 if (N->succ_size() > 1) 212 return true; 213 214 ProgramPoint P = N->getLocation(); 215 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>()) 216 S = BE->getBlock()->getTerminator(); 217 218 if (S == LoopStmt) 219 return false; 220 221 N = N->getFirstPred(); 222 } 223 224 llvm_unreachable("Reached root without encountering the previous step"); 225 } 226 227 // updateLoopStack is called on every basic block, therefore it needs to be fast 228 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, 229 ExplodedNode* Pred) { 230 auto State = Pred->getState(); 231 auto LCtx = Pred->getLocationContext(); 232 233 if (!isLoopStmt(LoopStmt)) 234 return State; 235 236 auto LS = State->get<LoopStack>(); 237 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && 238 LCtx == LS.getHead().getLocationContext()) { 239 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) { 240 State = State->set<LoopStack>(LS.getTail()); 241 State = State->add<LoopStack>(LoopState::getNormal(LoopStmt, LCtx)); 242 } 243 return State; 244 } 245 246 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred)) { 247 State = State->add<LoopStack>(LoopState::getNormal(LoopStmt, LCtx)); 248 return State; 249 } 250 251 State = State->add<LoopStack>(LoopState::getUnrolled(LoopStmt, LCtx)); 252 return State; 253 } 254 255 bool isUnrolledState(ProgramStateRef State) { 256 auto LS = State->get<LoopStack>(); 257 if (LS.isEmpty() || !LS.getHead().isUnrolled()) 258 return false; 259 return true; 260 } 261 } 262 } 263