1 //===--- LoopUnrolling.cpp - Unroll loops -----------------------*- C++ -*-===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 ///
/// This file contains functions which are used to decide whether a loop is
/// worth unrolling. It also contains functions which mark loops as unrolled in
/// the ProgramState and decide whether a CFGBlock belongs to an unrolled loop.
13 ///
14 //===----------------------------------------------------------------------===//
15 
#include "clang/Analysis/CFGStmtMap.h"
#include "clang/ASTMatchers/ASTMatchers.h"
#include "clang/ASTMatchers/ASTMatchFinder.h"
#include "clang/AST/ParentMap.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CallEvent.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/CheckerContext.h"
#include "clang/StaticAnalyzer/Core/PathSensitive/LoopUnrolling.h"
#include "llvm/ADT/Statistic.h"
#include <memory>
25 
26 using namespace clang;
27 using namespace ento;
28 using namespace clang::ast_matchers;
29 
30 #define DEBUG_TYPE "LoopUnrolling"
31 
// Incremented by markLoopAsUnrolled each time a loop is newly marked.
// NOTE(review): the description string says "completely unrolled", but the
// counter is bumped when a loop is marked, not when its simulation finishes.
STATISTIC(NumTimesLoopUnrolled,
          "The # of times a loop has got completely unrolled");

// Program-state map from each loop statement marked for unrolling to the
// FunctionDecl whose CFG is searched for that loop's blocks
// (see isUnrolledLoopBlock).
REGISTER_MAP_WITH_PROGRAMSTATE(UnrolledLoops, const Stmt *,
                               const FunctionDecl *)
37 
38 namespace clang {
39 namespace ento {
40 
41 static bool isLoopStmt(const Stmt *S) {
42   return S && (isa<ForStmt>(S) || isa<WhileStmt>(S) || isa<DoStmt>(S));
43 }
44 
45 static internal::Matcher<Stmt> simpleCondition(StringRef BindName) {
46   return binaryOperator(
47       anyOf(hasOperatorName("<"), hasOperatorName(">"), hasOperatorName("<="),
48             hasOperatorName(">="), hasOperatorName("!=")),
49       hasEitherOperand(ignoringParenImpCasts(
50           declRefExpr(to(varDecl(hasType(isInteger())).bind(BindName))))),
51       hasEitherOperand(ignoringParenImpCasts(integerLiteral())));
52 }
53 
54 static internal::Matcher<Stmt> changeIntBoundNode(StringRef NodeName) {
55   return anyOf(hasDescendant(unaryOperator(
56                    anyOf(hasOperatorName("--"), hasOperatorName("++")),
57                    hasUnaryOperand(ignoringParenImpCasts(
58                        declRefExpr(to(varDecl(equalsBoundNode(NodeName)))))))),
59                hasDescendant(binaryOperator(
60                    anyOf(hasOperatorName("="), hasOperatorName("+="),
61                          hasOperatorName("/="), hasOperatorName("*="),
62                          hasOperatorName("-=")),
63                    hasLHS(ignoringParenImpCasts(
64                        declRefExpr(to(varDecl(equalsBoundNode(NodeName)))))))));
65 }
66 
67 static internal::Matcher<Stmt> callByRef(StringRef NodeName) {
68   return hasDescendant(callExpr(forEachArgumentWithParam(
69       declRefExpr(to(varDecl(equalsBoundNode(NodeName)))),
70       parmVarDecl(hasType(references(qualType(unless(isConstQualified()))))))));
71 }
72 
73 static internal::Matcher<Stmt> assignedToRef(StringRef NodeName) {
74   return hasDescendant(varDecl(
75       allOf(hasType(referenceType()),
76             hasInitializer(
77                 anyOf(initListExpr(has(
78                           declRefExpr(to(varDecl(equalsBoundNode(NodeName)))))),
79                       declRefExpr(to(varDecl(equalsBoundNode(NodeName)))))))));
80 }
81 
82 static internal::Matcher<Stmt> getAddrTo(StringRef NodeName) {
83   return hasDescendant(unaryOperator(
84       hasOperatorName("&"),
85       hasUnaryOperand(declRefExpr(hasDeclaration(equalsBoundNode(NodeName))))));
86 }
87 
88 static internal::Matcher<Stmt> hasSuspiciousStmt(StringRef NodeName) {
89   return anyOf(hasDescendant(gotoStmt()), hasDescendant(switchStmt()),
90                // Escaping and not known mutation of the loop counter is handled
91                // by exclusion of assigning and address-of operators and
92                // pass-by-ref function calls on the loop counter from the body.
93                changeIntBoundNode(NodeName), callByRef(NodeName),
94                getAddrTo(NodeName), assignedToRef(NodeName));
95 }
96 
97 static internal::Matcher<Stmt> forLoopMatcher() {
98   return forStmt(
99              hasCondition(simpleCondition("initVarName")),
100              // Initialization should match the form: 'int i = 6' or 'i = 42'.
101              hasLoopInit(
102                  anyOf(declStmt(hasSingleDecl(
103                            varDecl(allOf(hasInitializer(integerLiteral()),
104                                          equalsBoundNode("initVarName"))))),
105                        binaryOperator(hasLHS(declRefExpr(to(varDecl(
106                                           equalsBoundNode("initVarName"))))),
107                                       hasRHS(integerLiteral())))),
108              // Incrementation should be a simple increment or decrement
109              // operator call.
110              hasIncrement(unaryOperator(
111                  anyOf(hasOperatorName("++"), hasOperatorName("--")),
112                  hasUnaryOperand(declRefExpr(
113                      to(varDecl(allOf(equalsBoundNode("initVarName"),
114                                       hasType(isInteger())))))))),
115              unless(hasBody(hasSuspiciousStmt("initVarName")))).bind("forLoop");
116 }
117 
118 bool shouldCompletelyUnroll(const Stmt *LoopStmt, ASTContext &ASTCtx) {
119 
120   if (!isLoopStmt(LoopStmt))
121     return false;
122 
123   // TODO: Match the cases where the bound is not a concrete literal but an
124   // integer with known value
125 
126   auto Matches = match(forLoopMatcher(), *LoopStmt, ASTCtx);
127   return !Matches.empty();
128 }
129 
130 namespace {
131 class LoopBlockVisitor : public ConstStmtVisitor<LoopBlockVisitor> {
132 public:
133   LoopBlockVisitor(llvm::SmallPtrSet<const CFGBlock *, 8> &BS) : BlockSet(BS) {}
134 
135   void VisitChildren(const Stmt *S) {
136     for (const Stmt *Child : S->children())
137       if (Child)
138         Visit(Child);
139   }
140 
141   void VisitStmt(const Stmt *S) {
142     // In case of nested loops we only unroll the inner loop if it's marked too.
143     if (!S || (isLoopStmt(S) && S != LoopStmt))
144       return;
145     BlockSet.insert(StmtToBlockMap->getBlock(S));
146     VisitChildren(S);
147   }
148 
149   void setBlocksOfLoop(const Stmt *Loop, const CFGStmtMap *M) {
150     BlockSet.clear();
151     StmtToBlockMap = M;
152     LoopStmt = Loop;
153     Visit(LoopStmt);
154   }
155 
156 private:
157   llvm::SmallPtrSet<const CFGBlock *, 8> &BlockSet;
158   const CFGStmtMap *StmtToBlockMap;
159   const Stmt *LoopStmt;
160 };
161 }
162 // TODO: refactor this function using LoopExit CFG element - once we have the
163 // information when the simulation reaches the end of the loop we can cleanup
164 // the state
165 bool isUnrolledLoopBlock(const CFGBlock *Block, ExplodedNode *Pred,
166                          AnalysisManager &AMgr) {
167   const Stmt *Term = Block->getTerminator();
168   auto State = Pred->getState();
169   // In case of nested loops in an inlined function should not be unrolled only
170   // if the inner loop is marked.
171   if (Term && isLoopStmt(Term) && !State->contains<UnrolledLoops>(Term))
172     return false;
173 
174   const CFGBlock *SearchedBlock;
175   llvm::SmallPtrSet<const CFGBlock *, 8> BlockSet;
176   LoopBlockVisitor LBV(BlockSet);
177   // Check the CFGBlocks of every marked loop.
178   for (auto &E : State->get<UnrolledLoops>()) {
179     SearchedBlock = Block;
180     const StackFrameContext *StackFrame = Pred->getStackFrame();
181     ParentMap PM(E.second->getBody());
182     CFGStmtMap *M = CFGStmtMap::Build(AMgr.getCFG(E.second), &PM);
183     LBV.setBlocksOfLoop(E.first, M);
184     // In case of an inlined function call check if any of its callSiteBlock is
185     // marked.
186     while (BlockSet.find(SearchedBlock) == BlockSet.end() && !StackFrame->inTopFrame()) {
187       SearchedBlock = StackFrame->getCallSiteBlock();
188       if(!SearchedBlock)
189         break;
190       StackFrame = StackFrame->getParent()->getCurrentStackFrame();
191     }
192     delete M;
193     if (SearchedBlock)
194       return true;
195   }
196   return false;
197 }
198 
199 ProgramStateRef markLoopAsUnrolled(const Stmt *Term, ProgramStateRef State,
200                                    const FunctionDecl *FD) {
201   if (State->contains<UnrolledLoops>(Term))
202     return State;
203 
204   State = State->set<UnrolledLoops>(Term, FD);
205   ++NumTimesLoopUnrolled;
206   return State;
207 }
208 }
209 }
210