1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "clang/Analysis/CFGStmtMap.h"
17 #include "clang/ASTMatchers/ASTMatchers.h"
18 #include "clang/ASTMatchers/ASTMatchFinder.h"
19 #include "clang/AST/ParentMap.h"
20 #include "clang/AST/StmtVisitor.h"
21 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
22 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
23 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
24 #include "llvm/ADT/Statistic.h"
25 
26 using namespace clang;
27 using namespace ento;
28 using namespace clang::ast_matchers;
29 
30 struct LoopState {
31 private:
32   enum Kind { Normal, Unrolled } K;
33   const Stmt *LoopStmt;
34   const LocationContext *LCtx;
35   LoopState(Kind InK, const Stmt *S, const LocationContext *L)
36       : K(InK), LoopStmt(S), LCtx(L) {}
37 
38 public:
39   static LoopState getNormal(const Stmt *S, const LocationContext *L) {
40     return LoopState(Normal, S, L);
41   }
42   static LoopState getUnrolled(const Stmt *S, const LocationContext *L) {
43     return LoopState(Unrolled, S, L);
44   }
45   bool isUnrolled() const { return K == Unrolled; }
46   const Stmt *getLoopStmt() const { return LoopStmt; }
47   const LocationContext *getLocationContext() const { return LCtx; }
48   bool operator==(const LoopState &X) const {
49     return K == X.K && LoopStmt == X.LoopStmt;
50   }
51   void Profile(llvm::FoldingSetNodeID &ID) const {
52     ID.AddInteger(K);
53     ID.AddPointer(LoopStmt);
54     ID.AddPointer(LCtx);
55   }
56 };
57 
// The tracked stack of loops. The stack records which loops contain the
// element currently being simulated; the innermost loop is at the head of the
// list. Each entry is marked according to whether we decided to unroll that
// loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
65 
66 namespace clang {
67 namespace ento {
68 
69 static bool isLoopStmt(const Stmt *S) {
70   return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
71 }
72 
73 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
74   auto LS = State->get<LoopStack>();
75   assert(!LS.isEmpty() && "Loop not added to the stack.");
76   assert(LoopStmt == LS.getHead().getLoopStmt() &&
77          "Loop is not on top of the stack.");
78 
79   State = State->set<LoopStack>(LS.getTail());
80   return State;
81 }
82 
83 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
84   return binaryOperator(
85       anyOf(hasOperatorName("<"), hasOperatorName(">"), hasOperatorName("<="),
86             hasOperatorName(">="), hasOperatorName("!=")),
87       hasEitherOperand(ignoringParenImpCasts(
88           declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))))),
89       hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
90 }
91 
92 static internal::Matcher<Stmt>
93 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
94   return anyOf(
95       unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
96                     hasUnaryOperand(ignoringParenImpCasts(
97                         declRefExpr(to(varDecl(VarNodeMatcher)))))),
98       binaryOperator(anyOf(hasOperatorName("="), hasOperatorName("+="),
99                            hasOperatorName("/="), hasOperatorName("*="),
100                            hasOperatorName("-=")),
101                      hasLHS(ignoringParenImpCasts(
102                          declRefExpr(to(varDecl(VarNodeMatcher)))))));
103 }
104 
105 static internal::Matcher<Stmt>
106 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
107   return callExpr(forEachArgumentWithParam(
108       declRefExpr(to(varDecl(VarNodeMatcher))),
109       parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
110 }
111 
112 static internal::Matcher<Stmt>
113 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
114   return declStmt(hasDescendant(varDecl(
115       allOf(hasType(referenceType()),
116             hasInitializer(anyOf(
117                 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
118                 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
119 }
120 
121 static internal::Matcher<Stmt>
122 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
123   return unaryOperator(
124       hasOperatorName("&"),
125       hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
126 }
127 
128 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
129   return hasDescendant(stmt(
130       anyOf(gotoStmt(), switchStmt(), returnStmt(),
131             // Escaping and not known mutation of the loop counter is handled
132             // by exclusion of assigning and address-of operators and
133             // pass-by-ref function calls on the loop counter from the body.
134             changeIntBoundNode(equalsBoundNode(NodeName)),
135             callByRef(equalsBoundNode(NodeName)),
136             getAddrTo(equalsBoundNode(NodeName)),
137             assignedToRef(equalsBoundNode(NodeName)))));
138 }
139 
140 static internal::Matcher<Stmt> forLoopMatcher() {
141   return forStmt(
142              hasCondition(simpleCondition("initVarName")),
143              // Initialization should match the form: 'int i = 6' or 'i = 42'.
144              hasLoopInit(
145                  anyOf(declStmt(hasSingleDecl(
146                            varDecl(allOf(hasInitializer(integerLiteral()),
147                                          equalsBoundNode("initVarName"))))),
148                        binaryOperator(hasLHS(declRefExpr(to(varDecl(
149                                           equalsBoundNode("initVarName"))))),
150                                       hasRHS(integerLiteral())))),
151              // Incrementation should be a simple increment or decrement
152              // operator call.
153              hasIncrement(unaryOperator(
154                  anyOf(hasOperatorName("++"), hasOperatorName("--")),
155                  hasUnaryOperand(declRefExpr(
156                      to(varDecl(allOf(equalsBoundNode("initVarName"),
157                                       hasType(isInteger())))))))),
158              unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
159 }
160 
161 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
162   // Global variables assumed as escaped variables.
163   if (VD->hasGlobalStorage())
164     return true;
165 
166   while (!N->pred_empty()) {
167     const Stmt *S = PathDiagnosticLocation::getStmt(N);
168     if (!S) {
169       N = N->getFirstPred();
170       continue;
171     }
172 
173     if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
174       for (const Decl *D : DS->decls()) {
175         // Once we reach the declaration of the VD we can return.
176         if (D->getCanonicalDecl() == VD)
177           return false;
178       }
179     }
180     // Check the usage of the pass-by-ref function calls and adress-of operator
181     // on VD and reference initialized by VD.
182     ASTContext &ASTCtx =
183         N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
184     auto Match =
185         match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
186                          assignedToRef(equalsNode(VD)))),
187               *S, ASTCtx);
188     if (!Match.empty())
189       return true;
190 
191     N = N->getFirstPred();
192   }
193   llvm_unreachable("Reached root without finding the declaration of VD");
194 }
195 
196 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
197                             ExplodedNode *Pred) {
198 
199   if (!isLoopStmt(LoopStmt))
200     return false;
201 
202   // TODO: Match the cases where the bound is not a concrete literal but an
203   // integer with known value
204   auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
205   if (Matches.empty())
206     return false;
207 
208   auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
209 
210   // Check if the counter of the loop is not escaped before.
211   return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
212 }
213 
214 // updateLoopStack is called on every basic block, therefore it needs to be fast
215 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
216                                 ExplodedNode* Pred) {
217   auto State = Pred->getState();
218   auto LCtx = Pred->getLocationContext();
219 
220   if (!isLoopStmt(LoopStmt))
221     return State;
222 
223   auto LS = State->get<LoopStack>();
224   if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
225       LCtx == LS.getHead().getLocationContext())
226     return State;
227 
228   if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred)) {
229     State = State->add<LoopStack>(LoopState::getNormal(LoopStmt, LCtx));
230     return State;
231   }
232 
233   State = State->add<LoopStack>(LoopState::getUnrolled(LoopStmt, LCtx));
234   return State;
235 }
236 
237 bool isUnrolledState(ProgramStateRef State) {
238   auto LS = State->get<LoopStack>();
239   if (LS.isEmpty() || !LS.getHead().isUnrolled())
240     return false;
241   return true;
242 }
243 }
244 }
245