1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. Moreover, these functions manage the stack of loops which
/// is tracked by the ProgramState.
13 ///
14 //===----------------------------------------------------------------------===//
15 
16 #include "clang/ASTMatchers/ASTMatchers.h"
17 #include "clang/ASTMatchers/ASTMatchFinder.h"
18 #include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
19 #include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
20 #include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
21 
22 using namespace clang;
23 using namespace ento;
24 using namespace clang::ast_matchers;
25 
26 struct LoopState {
27 private:
28   enum Kind { Normal, Unrolled } K;
29   const Stmt *LoopStmt;
30   const LocationContext *LCtx;
31   LoopState(Kind InK, const Stmt *S, const LocationContext *L)
32       : K(InK), LoopStmt(S), LCtx(L) {}
33 
34 public:
35   static LoopState getNormal(const Stmt *S, const LocationContext *L) {
36     return LoopState(Normal, S, L);
37   }
38   static LoopState getUnrolled(const Stmt *S, const LocationContext *L) {
39     return LoopState(Unrolled, S, L);
40   }
41   bool isUnrolled() const { return K == Unrolled; }
42   const Stmt *getLoopStmt() const { return LoopStmt; }
43   const LocationContext *getLocationContext() const { return LCtx; }
44   bool operator==(const LoopState &X) const {
45     return K == X.K && LoopStmt == X.LoopStmt;
46   }
47   void Profile(llvm::FoldingSetNodeID &ID) const {
48     ID.AddInteger(K);
49     ID.AddPointer(LoopStmt);
50     ID.AddPointer(LCtx);
51   }
52 };
53 
// The tracked stack of loops. The stack records which loops enclose the
// element that is currently being simulated; the functions in this file push
// an entry when a loop is reached and drop the head entry when the loop ends.
// Each entry is marked according to whether we decided to unroll that loop.
// TODO: The loop stack should not need to be in the program state since it is
// lexical in nature. Instead, the stack of loops should be tracked in the
// LocationContext.
REGISTER_LIST_WITH_PROGRAMSTATE(LoopStack, LoopState)
61 
62 namespace clang {
63 namespace ento {
64 
65 static bool isLoopStmt(const Stmt *S) {
66   return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
67 }
68 
69 ProgramStateRef processLoopEnd(const Stmt *LoopStmt, ProgramStateRef State) {
70   auto LS = State->get<LoopStack>();
71   if (!LS.isEmpty() && LS.getHead().getLoopStmt() == LoopStmt)
72     State = State->set<LoopStack>(LS.getTail());
73   return State;
74 }
75 
76 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
77   return binaryOperator(
78       anyOf(hasOperatorName("<"), hasOperatorName(">"), hasOperatorName("<="),
79             hasOperatorName(">="), hasOperatorName("!=")),
80       hasEitherOperand(ignoringParenImpCasts(
81           declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))))),
82       hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
83 }
84 
85 static internal::Matcher<Stmt>
86 changeIntBoundNode(internal::Matcher<Decl> VarNodeMatcher) {
87   return anyOf(
88       unaryOperator(anyOf(hasOperatorName("--"), hasOperatorName("++")),
89                     hasUnaryOperand(ignoringParenImpCasts(
90                         declRefExpr(to(varDecl(VarNodeMatcher)))))),
91       binaryOperator(anyOf(hasOperatorName("="), hasOperatorName("+="),
92                            hasOperatorName("/="), hasOperatorName("*="),
93                            hasOperatorName("-=")),
94                      hasLHS(ignoringParenImpCasts(
95                          declRefExpr(to(varDecl(VarNodeMatcher)))))));
96 }
97 
98 static internal::Matcher<Stmt>
99 callByRef(internal::Matcher<Decl> VarNodeMatcher) {
100   return callExpr(forEachArgumentWithParam(
101       declRefExpr(to(varDecl(VarNodeMatcher))),
102       parmVarDecl(hasType(references(qualType(unless(isConstQualified())))))));
103 }
104 
105 static internal::Matcher<Stmt>
106 assignedToRef(internal::Matcher<Decl> VarNodeMatcher) {
107   return declStmt(hasDescendant(varDecl(
108       allOf(hasType(referenceType()),
109             hasInitializer(anyOf(
110                 initListExpr(has(declRefExpr(to(varDecl(VarNodeMatcher))))),
111                 declRefExpr(to(varDecl(VarNodeMatcher)))))))));
112 }
113 
114 static internal::Matcher<Stmt>
115 getAddrTo(internal::Matcher<Decl> VarNodeMatcher) {
116   return unaryOperator(
117       hasOperatorName("&"),
118       hasUnaryOperand(declRefExpr(hasDeclaration(VarNodeMatcher))));
119 }
120 
121 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
122   return hasDescendant(stmt(
123       anyOf(gotoStmt(), switchStmt(), returnStmt(),
124             // Escaping and not known mutation of the loop counter is handled
125             // by exclusion of assigning and address-of operators and
126             // pass-by-ref function calls on the loop counter from the body.
127             changeIntBoundNode(equalsBoundNode(NodeName)),
128             callByRef(equalsBoundNode(NodeName)),
129             getAddrTo(equalsBoundNode(NodeName)),
130             assignedToRef(equalsBoundNode(NodeName)))));
131 }
132 
133 static internal::Matcher<Stmt> forLoopMatcher() {
134   return forStmt(
135              hasCondition(simpleCondition("initVarName")),
136              // Initialization should match the form: 'int i = 6' or 'i = 42'.
137              hasLoopInit(
138                  anyOf(declStmt(hasSingleDecl(
139                            varDecl(allOf(hasInitializer(integerLiteral()),
140                                          equalsBoundNode("initVarName"))))),
141                        binaryOperator(hasLHS(declRefExpr(to(varDecl(
142                                           equalsBoundNode("initVarName"))))),
143                                       hasRHS(integerLiteral())))),
144              // Incrementation should be a simple increment or decrement
145              // operator call.
146              hasIncrement(unaryOperator(
147                  anyOf(hasOperatorName("++"), hasOperatorName("--")),
148                  hasUnaryOperand(declRefExpr(
149                      to(varDecl(allOf(equalsBoundNode("initVarName"),
150                                       hasType(isInteger())))))))),
151              unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
152 }
153 
154 static bool isPossiblyEscaped(const VarDecl *VD, ExplodedNode *N) {
155   // Global variables assumed as escaped variables.
156   if (VD->hasGlobalStorage())
157     return true;
158 
159   while (!N->pred_empty()) {
160     const Stmt *S = PathDiagnosticLocation::getStmt(N);
161     if (!S) {
162       N = N->getFirstPred();
163       continue;
164     }
165 
166     if (const DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
167       for (const Decl *D : DS->decls()) {
168         // Once we reach the declaration of the VD we can return.
169         if (D->getCanonicalDecl() == VD)
170           return false;
171       }
172     }
173     // Check the usage of the pass-by-ref function calls and adress-of operator
174     // on VD and reference initialized by VD.
175     ASTContext &ASTCtx =
176         N->getLocationContext()->getAnalysisDeclContext()->getASTContext();
177     auto Match =
178         match(stmt(anyOf(callByRef(equalsNode(VD)), getAddrTo(equalsNode(VD)),
179                          assignedToRef(equalsNode(VD)))),
180               *S, ASTCtx);
181     if (!Match.empty())
182       return true;
183 
184     N = N->getFirstPred();
185   }
186   llvm_unreachable("Reached root without finding the declaration of VD");
187 }
188 
189 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx,
190                             ExplodedNode *Pred) {
191 
192   if (!isLoopStmt(LoopStmt))
193     return false;
194 
195   // TODO: Match the cases where the bound is not a concrete literal but an
196   // integer with known value
197   auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
198   if (Matches.empty())
199     return false;
200 
201   auto CounterVar = Matches[0].getNodeAs<VarDecl>("initVarName");
202 
203   // Check if the counter of the loop is not escaped before.
204   return !isPossiblyEscaped(CounterVar->getCanonicalDecl(), Pred);
205 }
206 
207 bool madeNewBranch(ExplodedNode* N, const Stmt* LoopStmt) {
208   const Stmt* S = nullptr;
209   while (!N->pred_empty())
210   {
211     if (N->succ_size() > 1)
212       return true;
213 
214     ProgramPoint P = N->getLocation();
215     if (Optional<BlockEntrance> BE = P.getAs<BlockEntrance>())
216       S = BE->getBlock()->getTerminator();
217 
218     if (S == LoopStmt)
219       return false;
220 
221     N = N->getFirstPred();
222   }
223 
224   llvm_unreachable("Reached root without encountering the previous step");
225 }
226 
227 // updateLoopStack is called on every basic block, therefore it needs to be fast
228 ProgramStateRef updateLoopStack(const Stmt *LoopStmt, ASTContext &ASTCtx,
229                                 ExplodedNode* Pred) {
230   auto State = Pred->getState();
231   auto LCtx = Pred->getLocationContext();
232 
233   if (!isLoopStmt(LoopStmt))
234     return State;
235 
236   auto LS = State->get<LoopStack>();
237   if (!LS.isEmpty() && LoopStmt == LS.getHead().getLoopStmt() &&
238       LCtx == LS.getHead().getLocationContext()) {
239     if (LS.getHead().isUnrolled() && madeNewBranch(Pred, LoopStmt)) {
240       State = State->set<LoopStack>(LS.getTail());
241       State = State->add<LoopStack>(LoopState::getNormal(LoopStmt, LCtx));
242     }
243     return State;
244   }
245 
246   if (!shouldCompletelyUnroll(LoopStmt, ASTCtx, Pred)) {
247     State = State->add<LoopStack>(LoopState::getNormal(LoopStmt, LCtx));
248     return State;
249   }
250 
251   State = State->add<LoopStack>(LoopState::getUnrolled(LoopStmt, LCtx));
252   return State;
253 }
254 
255 bool isUnrolledState(ProgramStateRef State) {
256   auto LS = State->get<LoopStack>();
257   if (LS.isEmpty() || !LS.getHead().isUnrolled())
258     return false;
259   return true;
260 }
261 }
262 }
263