1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 /// 10 /// This file contains functions which are used to decide if a loop worth to be 11 /// unrolled. Moreover, these functions manages the stack of loop which is 12 /// tracked by the ProgramState. 13 /// 14 //===----------------------------------------------------------------------===// 15 16 #include "clang/ASTMatchers/ASTMatchers.h" 17 #include "clang/ASTMatchers/ASTMatchFinder.h" 18 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h" 19 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h" 20 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h" 21 22 using namespace clang; 23 using namespace ento; 24 using namespace clang::ast_matchers; 25 26 static const int MAXIMUM_STEP_UNROLLED = 128; 27 28 struct LoopState { 29 private: 30 enum Kind { Normal, Unrolled } K; 31 const Stmt *LoopStmt; 32 const LocationContext *LCtx; 33 unsigned maxStep; 34 LoopState(Kind InK, const Stmt *S, const LocationContext *L, unsigned N) 35 : K(InK), LoopStmt(S), LCtx(L), maxStep(N) {} 36 37 public: 38 static LoopState getNormal(const Stmt *S, const LocationContext *L, 39 unsigned N) { 40 return LoopState(Normal, S, L, N); 41 } 42 static LoopState getUnrolled(const Stmt *S, const LocationContext *L, 43 unsigned N) { 44 return LoopState(Unrolled, S, L, N); 45 } 46 bool isUnrolled() const { return K == Unrolled; } 47 unsigned getMaxStep() const { return maxStep; } 48 const Stmt *getLoopStmt() const { return LoopStmt; } 49 const LocationContext *getLocationContext() const { return LCtx; } 50 bool operator==(const LoopState &X) const { 51 return K == X.K && LoopStmt == X.LoopStmt; 52 } 53 void Profile(llvm::FoldingSetNodeID &ID) const { 54 ID.AddInteger(K); 55 ID.AddPointer(LoopStmt); 56 ID.AddPointer(LCtx); 57 ID.AddInteger(maxStep); 58 } 59 }; 60 61 // The tracked stack of loops. The stack indicates that which loops the 62 // simulated element contained by. The loops are marked depending if we decided 63 // to unroll them. 64 // TODO: The loop stack should not need to be in the program state since it is 65 // lexical in nature. Instead, the stack of loops should be tracked in the 66 // LocationContext. 67 REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState) 68 69 namespace clang { 70 namespace ento { 71 72 static bool isLoopStmt(const Stmt *S) { 73 return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S)); 74 } 75 76 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) { 77 auto LS = State->get<LoopStack>(); 78 if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt) 79 State = State->set<LoopStack>(LS.getTail()); 80 return State; 81 } 82 83 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) { 84 return binaryOperator(anyOf(hasOperatorName("<"), hasOperatorName(">"), 85 hasOperatorName("<="), hasOperatorName(">="), 86 hasOperatorName("!=")), 87 hasEitherOperand(ignoringParenImpCasts(declRefExpr( 88 to(varDecl(hasType(isInteger())).bind(BindName))))), 89 hasEitherOperand(ignoringParenImpCasts( 90 integerLiteral().bind("boundNum")))) 91 .bind("conditionOperator"); 92 } 93 94 static internal::Matcher<Stmt> 95 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) { 96 return anyOf( 97 unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")), 98 hasUnaryOperand(ignoringParenImpCasts( 99 declRefExpr(to(varDecl(VarNodeMatcher)))))), 100 binaryOperator(anyOf(hasOperatorName("="), hasOperatorName("+="), 101 hasOperatorName("/="), hasOperatorName("*="), 102 hasOperatorName("-=")), 103 hasLHS(ignoringParenImpCasts( 104 declRefExpr(to(varDecl(VarNodeMatcher))))))); 105 } 106 107 static internal::Matcher<Stmt> 108 callByRef(internal::Matcher<Decl> VarNodeMatcher) { 109 return callExpr(forEachArgumentWithParam( 110 declRefExpr(to(varDecl(VarNodeMatcher))), 111 parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))); 112 } 113 114 static internal::Matcher<Stmt> 115 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) { 116 return declStmt(hasDescendant(varDecl( 117 allOf(hasType(referenceType()), 118 hasInitializer(anyOf( 119 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))), 120 declRefExpr(to(varDecl(VarNodeMatcher))))))))); 121 } 122 123 static internal::Matcher<Stmt> 124 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) { 125 return unaryOperator( 126 hasOperatorName("&"), 127 hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher)))); 128 } 129 130 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) { 131 return hasDescendant(stmt( 132 anyOf(gotoStmt(), switchStmt(), returnStmt(), 133 // Escaping and not known mutation of the loop counter is handled 134 // by exclusion of assigning and address-of operators and 135 // pass-by-ref function calls on the loop counter from the body. 136 changeIntBoundNode(equalsBoundNode(NodeName)), 137 callByRef(equalsBoundNode(NodeName)), 138 getAddrTo(equalsBoundNode(NodeName)), 139 assignedToRef(equalsBoundNode(NodeName))))); 140 } 141 142 static internal::Matcher<Stmt> forLoopMatcher() { 143 return forStmt( 144 hasCondition(simpleCondition("initVarName")), 145 // Initialization should match the form: 'int i = 6' or 'i = 42'. 146 hasLoopInit(anyOf( 147 declStmt(hasSingleDecl(varDecl( 148 allOf(hasInitializer(integerLiteral().bind("initNum")), 149 equalsBoundNode("initVarName"))))), 150 binaryOperator(hasLHS(declRefExpr(to( 151 varDecl(equalsBoundNode("initVarName"))))), 152 hasRHS(integerLiteral().bind("initNum"))))), 153 // Incrementation should be a simple increment or decrement 154 // operator call. 155 hasIncrement(unaryOperator( 156 anyOf(hasOperatorName("++"), hasOperatorName("--")), 157 hasUnaryOperand(declRefExpr( 158 to(varDecl(allOf(equalsBoundNode("initVarName"), 159 hasType(isInteger())))))))), 160 unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop"); 161 } 162 163 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) { 164 // Global variables assumed as escaped variables. 165 if (VD->hasGlobalStorage()) 166 return true; 167 168 while (!N->pred_empty()) { 169 const Stmt *S = PathDiagnosticLocation::getStmt(N); 170 if (!S) { 171 N = N->getFirstPred(); 172 continue; 173 } 174 175 if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) { 176 for (const Decl *D : DS->decls()) { 177 // Once we reach the declaration of the VD we can return. 178 if (D->getCanonicalDecl() == VD) 179 return false; 180 } 181 } 182 // Check the usage of the pass-by-ref function calls and adress-of operator 183 // on VD and reference initialized by VD. 184 ASTContext &ASTCtx = 185 N->getLocationContext()->getAnalysisDeclContext()->getASTContext(); 186 auto Match = 187 match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)), 188 assignedToRef(equalsNode(VD)))), 189 *S, ASTCtx); 190 if (!Match.empty()) 191 return true; 192 193 N = N->getFirstPred(); 194 } 195 llvm_unreachable("Reached root without finding the declaration of VD"); 196 } 197 198 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx, 199 ExplodedNode *Pred, unsigned &maxStep) { 200 201 if (!isLoopStmt(LoopStmt)) 202 return false; 203 204 // TODO: Match the cases where the bound is not a concrete literal but an 205 // integer with known value 206 auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx); 207 if (Matches.empty()) 208 return false; 209 210 auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName"); 211 llvm::APInt BoundNum = 212 Matches[0].getNodeAs<IntegerLiteral>("boundNum")->getValue(); 213 llvm::APInt InitNum = 214 Matches[0].getNodeAs<IntegerLiteral>("initNum")->getValue(); 215 auto CondOp = Matches[0].getNodeAs<BinaryOperator>("conditionOperator"); 216 if (InitNum.getBitWidth() != BoundNum.getBitWidth()) { 217 InitNum = InitNum.zextOrSelf(BoundNum.getBitWidth()); 218 BoundNum = BoundNum.zextOrSelf(InitNum.getBitWidth()); 219 } 220 221 if (CondOp->getOpcode() == BO_GE || CondOp->getOpcode() == BO_LE) 222 maxStep = (BoundNum - InitNum + 1).abs().getZExtValue(); 223 else 224 maxStep = (BoundNum - InitNum).abs().getZExtValue(); 225 226 // Check if the counter of the loop is not escaped before. 227 return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred); 228 } 229 230 bool madeNewBranch(ExplodedNode *N, const Stmt *LoopStmt) { 231 const Stmt *S = nullptr; 232 while (!N->pred_empty()) { 233 if (N->succ_size() > 1) 234 return true; 235 236 ProgramPoint P = N->getLocation(); 237 if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>()) 238 S = BE->getBlock()->getTerminator(); 239 240 if (S == LoopStmt) 241 return false; 242 243 N = N->getFirstPred(); 244 } 245 246 llvm_unreachable("Reached root without encountering the previous step"); 247 } 248 249 // updateLoopStack is called on every basic block, therefore it needs to be fast 250 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx, 251 ExplodedNode *Pred, unsigned maxVisitOnPath) { 252 auto State = Pred->getState(); 253 auto LCtx = Pred->getLocationContext(); 254 255 if (!isLoopStmt(LoopStmt)) 256 return State; 257 258 auto LS = State->get<LoopStack>(); 259 if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() && 260 LCtx == LS.getHead().getLocationContext()) { 261 if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) { 262 State = State->set<LoopStack>(LS.getTail()); 263 State = State->add<LoopStack>( 264 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 265 } 266 return State; 267 } 268 unsigned maxStep; 269 if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred, maxStep)) { 270 State = State->add<LoopStack>( 271 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 272 return State; 273 } 274 275 unsigned outerStep = (LS.isEmpty() ? 1 : LS.getHead().getMaxStep()); 276 277 unsigned innerMaxStep = maxStep * outerStep; 278 if (innerMaxStep > MAXIMUM_STEP_UNROLLED) 279 State = State->add<LoopStack>( 280 LoopState::getNormal(LoopStmt, LCtx, maxVisitOnPath)); 281 else 282 State = State->add<LoopStack>( 283 LoopState::getUnrolled(LoopStmt, LCtx, innerMaxStep)); 284 return State; 285 } 286 287 bool isUnrolledState(ProgramStateRef State) { 288 auto LS = State->get<LoopStack>(); 289 if (LS.isEmpty() || !LS.getHead().isUnrolled()) 290 return false; 291 return true; 292 } 293 } 294 } 295