//===--- TransBlockObjCVariable.cpp - Transformations to ARC mode ---------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // rewriteBlockObjCVariable:
11 //
// Adding __block to an obj-c variable could be either because the variable
// is used for output storage or the user wanted to break a retain cycle.
// This transformation checks whether a reference to the variable is actually
// needed inside the block (it is assigned to or its address is taken) or not.
// If the reference is not needed it will assume __block was added to break a
// cycle so it will remove '__block' and add __weak/__unsafe_unretained.
// e.g.
19 //
20 //   __block Foo *x;
21 //   bar(^ { [x cake]; });
22 // ---->
23 //   __weak Foo *x;
24 //   bar(^ { [x cake]; });
25 //
26 //===----------------------------------------------------------------------===//
27 
28 #include "Transforms.h"
29 #include "Internals.h"
30 #include "clang/Basic/SourceManager.h"
31 
32 using namespace clang;
33 using namespace arcmt;
34 using namespace trans;
35 
36 namespace {
37 
38 class RootBlockObjCVarRewriter :
39                           public RecursiveASTVisitor<RootBlockObjCVarRewriter> {
40   llvm::DenseSet<VarDecl *> &VarsToChange;
41 
42   class BlockVarChecker : public RecursiveASTVisitor<BlockVarChecker> {
43     VarDecl *Var;
44 
45     typedef RecursiveASTVisitor<BlockVarChecker> base;
46   public:
47     BlockVarChecker(VarDecl *var) : Var(var) { }
48 
49     bool TraverseImplicitCastExpr(ImplicitCastExpr *castE) {
50       if (DeclRefExpr *
51             ref = dyn_cast<DeclRefExpr>(castE->getSubExpr())) {
52         if (ref->getDecl() == Var) {
53           if (castE->getCastKind() == CK_LValueToRValue)
54             return true; // Using the value of the variable.
55           if (castE->getCastKind() == CK_NoOp && castE->isLValue() &&
56               Var->getASTContext().getLangOpts().CPlusPlus)
57             return true; // Binding to const C++ reference.
58         }
59       }
60 
61       return base::TraverseImplicitCastExpr(castE);
62     }
63 
64     bool VisitDeclRefExpr(DeclRefExpr *E) {
65       if (E->getDecl() == Var)
66         return false; // The reference of the variable, and not just its value,
67                       //  is needed.
68       return true;
69     }
70   };
71 
72 public:
73   RootBlockObjCVarRewriter(llvm::DenseSet<VarDecl *> &VarsToChange)
74     : VarsToChange(VarsToChange) { }
75 
76   bool VisitBlockDecl(BlockDecl *block) {
77     SmallVector<VarDecl *, 4> BlockVars;
78 
79     for (BlockDecl::capture_iterator
80            I = block->capture_begin(), E = block->capture_end(); I != E; ++I) {
81       VarDecl *var = I->getVariable();
82       if (I->isByRef() &&
83           var->getType()->isObjCObjectPointerType() &&
84           isImplicitStrong(var->getType())) {
85         BlockVars.push_back(var);
86       }
87     }
88 
89     for (unsigned i = 0, e = BlockVars.size(); i != e; ++i) {
90       VarDecl *var = BlockVars[i];
91 
92       BlockVarChecker checker(var);
93       bool onlyValueOfVarIsNeeded = checker.TraverseStmt(block->getBody());
94       if (onlyValueOfVarIsNeeded)
95         VarsToChange.insert(var);
96       else
97         VarsToChange.erase(var);
98     }
99 
100     return true;
101   }
102 
103 private:
104   bool isImplicitStrong(QualType ty) {
105     if (isa<AttributedType>(ty.getTypePtr()))
106       return false;
107     return ty.getLocalQualifiers().getObjCLifetime() == Qualifiers::OCL_Strong;
108   }
109 };
110 
111 class BlockObjCVarRewriter : public RecursiveASTVisitor<BlockObjCVarRewriter> {
112   MigrationPass &Pass;
113   llvm::DenseSet<VarDecl *> &VarsToChange;
114 
115 public:
116   BlockObjCVarRewriter(MigrationPass &pass,
117                        llvm::DenseSet<VarDecl *> &VarsToChange)
118     : Pass(pass), VarsToChange(VarsToChange) { }
119 
120   bool TraverseBlockDecl(BlockDecl *block) {
121     RootBlockObjCVarRewriter(VarsToChange).TraverseDecl(block);
122     return true;
123   }
124 };
125 
126 } // anonymous namespace
127 
128 void BlockObjCVariableTraverser::traverseBody(BodyContext &BodyCtx) {
129   MigrationPass &Pass = BodyCtx.getMigrationContext().Pass;
130   llvm::DenseSet<VarDecl *> VarsToChange;
131 
132   BlockObjCVarRewriter trans(Pass, VarsToChange);
133   trans.TraverseStmt(BodyCtx.getTopStmt());
134 
135   for (llvm::DenseSet<VarDecl *>::iterator
136          I = VarsToChange.begin(), E = VarsToChange.end(); I != E; ++I) {
137     VarDecl *var = *I;
138     BlocksAttr *attr = var->getAttr<BlocksAttr>();
139     if(!attr)
140       continue;
141     bool useWeak = canApplyWeak(Pass.Ctx, var->getType());
142     SourceManager &SM = Pass.Ctx.getSourceManager();
143     Transaction Trans(Pass.TA);
144     Pass.TA.replaceText(SM.getExpansionLoc(attr->getLocation()),
145                         "__block",
146                         useWeak ? "__weak" : "__unsafe_unretained");
147   }
148 }
149