//===--- Transforms.cpp - Transformations to ARC mode ---------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "Transforms.h"
11 #include "Internals.h"
12 #include "clang/Analysis/DomainSpecific/CocoaConventions.h"
13 #include "clang/AST/ASTContext.h"
14 #include "clang/AST/RecursiveASTVisitor.h"
15 #include "clang/AST/StmtVisitor.h"
16 #include "clang/Basic/SourceManager.h"
17 #include "clang/Lex/Lexer.h"
18 #include "clang/Sema/Sema.h"
19 #include "clang/Sema/SemaDiagnostic.h"
20 #include "llvm/ADT/StringSwitch.h"
21 #include "llvm/ADT/DenseSet.h"
22 #include <map>
23 
24 using namespace clang;
25 using namespace arcmt;
26 using namespace trans;
27 
28 ASTTraverser::~ASTTraverser() { }
29 
30 bool MigrationPass::CFBridgingFunctionsDefined() {
31   if (!EnableCFBridgeFns.hasValue())
32     EnableCFBridgeFns = SemaRef.isKnownName("CFBridgingRetain") &&
33                         SemaRef.isKnownName("CFBridgingRelease");
34   return *EnableCFBridgeFns;
35 }
36 
37 //===----------------------------------------------------------------------===//
38 // Helpers.
39 //===----------------------------------------------------------------------===//
40 
41 bool trans::canApplyWeak(ASTContext &Ctx, QualType type,
42                          bool AllowOnUnknownClass) {
43   if (!Ctx.getLangOpts().ObjCRuntimeHasWeak)
44     return false;
45 
46   QualType T = type;
47   if (T.isNull())
48     return false;
49 
50   // iOS is always safe to use 'weak'.
51   if (Ctx.getTargetInfo().getTriple().getOS() == llvm::Triple::IOS)
52     AllowOnUnknownClass = true;
53 
54   while (const PointerType *ptr = T->getAs<PointerType>())
55     T = ptr->getPointeeType();
56   if (const ObjCObjectPointerType *ObjT = T->getAs<ObjCObjectPointerType>()) {
57     ObjCInterfaceDecl *Class = ObjT->getInterfaceDecl();
58     if (!AllowOnUnknownClass && (!Class || Class->getName() == "NSObject"))
59       return false; // id/NSObject is not safe for weak.
60     if (!AllowOnUnknownClass && !Class->hasDefinition())
61       return false; // forward classes are not verifiable, therefore not safe.
62     if (Class->isArcWeakrefUnavailable())
63       return false;
64   }
65 
66   return true;
67 }
68 
69 bool trans::isPlusOneAssign(const BinaryOperator *E) {
70   if (E->getOpcode() != BO_Assign)
71     return false;
72 
73   if (const ObjCMessageExpr *
74         ME = dyn_cast<ObjCMessageExpr>(E->getRHS()->IgnoreParenCasts()))
75     if (ME->getMethodFamily() == OMF_retain)
76       return true;
77 
78   if (const CallExpr *
79         callE = dyn_cast<CallExpr>(E->getRHS()->IgnoreParenCasts())) {
80     if (const FunctionDecl *FD = callE->getDirectCallee()) {
81       if (FD->getAttr<CFReturnsRetainedAttr>())
82         return true;
83 
84       if (FD->isGlobal() &&
85           FD->getIdentifier() &&
86           FD->getParent()->isTranslationUnit() &&
87           FD->getLinkage() == ExternalLinkage &&
88           ento::cocoa::isRefType(callE->getType(), "CF",
89                                  FD->getIdentifier()->getName())) {
90         StringRef fname = FD->getIdentifier()->getName();
91         if (fname.endswith("Retain") ||
92             fname.find("Create") != StringRef::npos ||
93             fname.find("Copy") != StringRef::npos) {
94           return true;
95         }
96       }
97     }
98   }
99 
100   const ImplicitCastExpr *implCE = dyn_cast<ImplicitCastExpr>(E->getRHS());
101   while (implCE && implCE->getCastKind() ==  CK_BitCast)
102     implCE = dyn_cast<ImplicitCastExpr>(implCE->getSubExpr());
103 
104   if (implCE && implCE->getCastKind() == CK_ARCConsumeObject)
105     return true;
106 
107   return false;
108 }
109 
110 /// \brief 'Loc' is the end of a statement range. This returns the location
111 /// immediately after the semicolon following the statement.
112 /// If no semicolon is found or the location is inside a macro, the returned
113 /// source location will be invalid.
114 SourceLocation trans::findLocationAfterSemi(SourceLocation loc,
115                                             ASTContext &Ctx) {
116   SourceLocation SemiLoc = findSemiAfterLocation(loc, Ctx);
117   if (SemiLoc.isInvalid())
118     return SourceLocation();
119   return SemiLoc.getLocWithOffset(1);
120 }
121 
122 /// \brief \arg Loc is the end of a statement range. This returns the location
123 /// of the semicolon following the statement.
124 /// If no semicolon is found or the location is inside a macro, the returned
125 /// source location will be invalid.
126 SourceLocation trans::findSemiAfterLocation(SourceLocation loc,
127                                             ASTContext &Ctx) {
128   SourceManager &SM = Ctx.getSourceManager();
129   if (loc.isMacroID()) {
130     if (!Lexer::isAtEndOfMacroExpansion(loc, SM, Ctx.getLangOpts(), &loc))
131       return SourceLocation();
132   }
133   loc = Lexer::getLocForEndOfToken(loc, /*Offset=*/0, SM, Ctx.getLangOpts());
134 
135   // Break down the source location.
136   std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(loc);
137 
138   // Try to load the file buffer.
139   bool invalidTemp = false;
140   StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
141   if (invalidTemp)
142     return SourceLocation();
143 
144   const char *tokenBegin = file.data() + locInfo.second;
145 
146   // Lex from the start of the given location.
147   Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
148               Ctx.getLangOpts(),
149               file.begin(), tokenBegin, file.end());
150   Token tok;
151   lexer.LexFromRawLexer(tok);
152   if (tok.isNot(tok::semi))
153     return SourceLocation();
154 
155   return tok.getLocation();
156 }
157 
158 bool trans::hasSideEffects(Expr *E, ASTContext &Ctx) {
159   if (!E || !E->HasSideEffects(Ctx))
160     return false;
161 
162   E = E->IgnoreParenCasts();
163   ObjCMessageExpr *ME = dyn_cast<ObjCMessageExpr>(E);
164   if (!ME)
165     return true;
166   switch (ME->getMethodFamily()) {
167   case OMF_autorelease:
168   case OMF_dealloc:
169   case OMF_release:
170   case OMF_retain:
171     switch (ME->getReceiverKind()) {
172     case ObjCMessageExpr::SuperInstance:
173       return false;
174     case ObjCMessageExpr::Instance:
175       return hasSideEffects(ME->getInstanceReceiver(), Ctx);
176     default:
177       break;
178     }
179     break;
180   default:
181     break;
182   }
183 
184   return true;
185 }
186 
187 bool trans::isGlobalVar(Expr *E) {
188   E = E->IgnoreParenCasts();
189   if (DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(E))
190     return DRE->getDecl()->getDeclContext()->isFileContext() &&
191            DRE->getDecl()->getLinkage() == ExternalLinkage;
192   if (ConditionalOperator *condOp = dyn_cast<ConditionalOperator>(E))
193     return isGlobalVar(condOp->getTrueExpr()) &&
194            isGlobalVar(condOp->getFalseExpr());
195 
196   return false;
197 }
198 
199 StringRef trans::getNilString(ASTContext &Ctx) {
200   if (Ctx.Idents.get("nil").hasMacroDefinition())
201     return "nil";
202   else
203     return "0";
204 }
205 
206 namespace {
207 
208 class ReferenceClear : public RecursiveASTVisitor<ReferenceClear> {
209   ExprSet &Refs;
210 public:
211   ReferenceClear(ExprSet &refs) : Refs(refs) { }
212   bool VisitDeclRefExpr(DeclRefExpr *E) { Refs.erase(E); return true; }
213 };
214 
215 class ReferenceCollector : public RecursiveASTVisitor<ReferenceCollector> {
216   ValueDecl *Dcl;
217   ExprSet &Refs;
218 
219 public:
220   ReferenceCollector(ValueDecl *D, ExprSet &refs)
221     : Dcl(D), Refs(refs) { }
222 
223   bool VisitDeclRefExpr(DeclRefExpr *E) {
224     if (E->getDecl() == Dcl)
225       Refs.insert(E);
226     return true;
227   }
228 };
229 
230 class RemovablesCollector : public RecursiveASTVisitor<RemovablesCollector> {
231   ExprSet &Removables;
232 
233 public:
234   RemovablesCollector(ExprSet &removables)
235   : Removables(removables) { }
236 
237   bool shouldWalkTypesOfTypeLocs() const { return false; }
238 
239   bool TraverseStmtExpr(StmtExpr *E) {
240     CompoundStmt *S = E->getSubStmt();
241     for (CompoundStmt::body_iterator
242         I = S->body_begin(), E = S->body_end(); I != E; ++I) {
243       if (I != E - 1)
244         mark(*I);
245       TraverseStmt(*I);
246     }
247     return true;
248   }
249 
250   bool VisitCompoundStmt(CompoundStmt *S) {
251     for (CompoundStmt::body_iterator
252         I = S->body_begin(), E = S->body_end(); I != E; ++I)
253       mark(*I);
254     return true;
255   }
256 
257   bool VisitIfStmt(IfStmt *S) {
258     mark(S->getThen());
259     mark(S->getElse());
260     return true;
261   }
262 
263   bool VisitWhileStmt(WhileStmt *S) {
264     mark(S->getBody());
265     return true;
266   }
267 
268   bool VisitDoStmt(DoStmt *S) {
269     mark(S->getBody());
270     return true;
271   }
272 
273   bool VisitForStmt(ForStmt *S) {
274     mark(S->getInit());
275     mark(S->getInc());
276     mark(S->getBody());
277     return true;
278   }
279 
280 private:
281   void mark(Stmt *S) {
282     if (!S) return;
283 
284     while (LabelStmt *Label = dyn_cast<LabelStmt>(S))
285       S = Label->getSubStmt();
286     S = S->IgnoreImplicit();
287     if (Expr *E = dyn_cast<Expr>(S))
288       Removables.insert(E);
289   }
290 };
291 
292 } // end anonymous namespace
293 
294 void trans::clearRefsIn(Stmt *S, ExprSet &refs) {
295   ReferenceClear(refs).TraverseStmt(S);
296 }
297 
298 void trans::collectRefs(ValueDecl *D, Stmt *S, ExprSet &refs) {
299   ReferenceCollector(D, refs).TraverseStmt(S);
300 }
301 
302 void trans::collectRemovables(Stmt *S, ExprSet &exprs) {
303   RemovablesCollector(exprs).TraverseStmt(S);
304 }
305 
306 //===----------------------------------------------------------------------===//
307 // MigrationContext
308 //===----------------------------------------------------------------------===//
309 
310 namespace {
311 
312 class ASTTransform : public RecursiveASTVisitor<ASTTransform> {
313   MigrationContext &MigrateCtx;
314   typedef RecursiveASTVisitor<ASTTransform> base;
315 
316 public:
317   ASTTransform(MigrationContext &MigrateCtx) : MigrateCtx(MigrateCtx) { }
318 
319   bool shouldWalkTypesOfTypeLocs() const { return false; }
320 
321   bool TraverseObjCImplementationDecl(ObjCImplementationDecl *D) {
322     ObjCImplementationContext ImplCtx(MigrateCtx, D);
323     for (MigrationContext::traverser_iterator
324            I = MigrateCtx.traversers_begin(),
325            E = MigrateCtx.traversers_end(); I != E; ++I)
326       (*I)->traverseObjCImplementation(ImplCtx);
327 
328     return base::TraverseObjCImplementationDecl(D);
329   }
330 
331   bool TraverseStmt(Stmt *rootS) {
332     if (!rootS)
333       return true;
334 
335     BodyContext BodyCtx(MigrateCtx, rootS);
336     for (MigrationContext::traverser_iterator
337            I = MigrateCtx.traversers_begin(),
338            E = MigrateCtx.traversers_end(); I != E; ++I)
339       (*I)->traverseBody(BodyCtx);
340 
341     return true;
342   }
343 };
344 
345 }
346 
347 MigrationContext::~MigrationContext() {
348   for (traverser_iterator
349          I = traversers_begin(), E = traversers_end(); I != E; ++I)
350     delete *I;
351 }
352 
353 bool MigrationContext::isGCOwnedNonObjC(QualType T) {
354   while (!T.isNull()) {
355     if (const AttributedType *AttrT = T->getAs<AttributedType>()) {
356       if (AttrT->getAttrKind() == AttributedType::attr_objc_ownership)
357         return !AttrT->getModifiedType()->isObjCRetainableType();
358     }
359 
360     if (T->isArrayType())
361       T = Pass.Ctx.getBaseElementType(T);
362     else if (const PointerType *PT = T->getAs<PointerType>())
363       T = PT->getPointeeType();
364     else if (const ReferenceType *RT = T->getAs<ReferenceType>())
365       T = RT->getPointeeType();
366     else
367       break;
368   }
369 
370   return false;
371 }
372 
373 bool MigrationContext::rewritePropertyAttribute(StringRef fromAttr,
374                                                 StringRef toAttr,
375                                                 SourceLocation atLoc) {
376   if (atLoc.isMacroID())
377     return false;
378 
379   SourceManager &SM = Pass.Ctx.getSourceManager();
380 
381   // Break down the source location.
382   std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(atLoc);
383 
384   // Try to load the file buffer.
385   bool invalidTemp = false;
386   StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
387   if (invalidTemp)
388     return false;
389 
390   const char *tokenBegin = file.data() + locInfo.second;
391 
392   // Lex from the start of the given location.
393   Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
394               Pass.Ctx.getLangOpts(),
395               file.begin(), tokenBegin, file.end());
396   Token tok;
397   lexer.LexFromRawLexer(tok);
398   if (tok.isNot(tok::at)) return false;
399   lexer.LexFromRawLexer(tok);
400   if (tok.isNot(tok::raw_identifier)) return false;
401   if (StringRef(tok.getRawIdentifierData(), tok.getLength())
402         != "property")
403     return false;
404   lexer.LexFromRawLexer(tok);
405   if (tok.isNot(tok::l_paren)) return false;
406 
407   Token BeforeTok = tok;
408   Token AfterTok;
409   AfterTok.startToken();
410   SourceLocation AttrLoc;
411 
412   lexer.LexFromRawLexer(tok);
413   if (tok.is(tok::r_paren))
414     return false;
415 
416   while (1) {
417     if (tok.isNot(tok::raw_identifier)) return false;
418     StringRef ident(tok.getRawIdentifierData(), tok.getLength());
419     if (ident == fromAttr) {
420       if (!toAttr.empty()) {
421         Pass.TA.replaceText(tok.getLocation(), fromAttr, toAttr);
422         return true;
423       }
424       // We want to remove the attribute.
425       AttrLoc = tok.getLocation();
426     }
427 
428     do {
429       lexer.LexFromRawLexer(tok);
430       if (AttrLoc.isValid() && AfterTok.is(tok::unknown))
431         AfterTok = tok;
432     } while (tok.isNot(tok::comma) && tok.isNot(tok::r_paren));
433     if (tok.is(tok::r_paren))
434       break;
435     if (AttrLoc.isInvalid())
436       BeforeTok = tok;
437     lexer.LexFromRawLexer(tok);
438   }
439 
440   if (toAttr.empty() && AttrLoc.isValid() && AfterTok.isNot(tok::unknown)) {
441     // We want to remove the attribute.
442     if (BeforeTok.is(tok::l_paren) && AfterTok.is(tok::r_paren)) {
443       Pass.TA.remove(SourceRange(BeforeTok.getLocation(),
444                                  AfterTok.getLocation()));
445     } else if (BeforeTok.is(tok::l_paren) && AfterTok.is(tok::comma)) {
446       Pass.TA.remove(SourceRange(AttrLoc, AfterTok.getLocation()));
447     } else {
448       Pass.TA.remove(SourceRange(BeforeTok.getLocation(), AttrLoc));
449     }
450 
451     return true;
452   }
453 
454   return false;
455 }
456 
457 bool MigrationContext::addPropertyAttribute(StringRef attr,
458                                             SourceLocation atLoc) {
459   if (atLoc.isMacroID())
460     return false;
461 
462   SourceManager &SM = Pass.Ctx.getSourceManager();
463 
464   // Break down the source location.
465   std::pair<FileID, unsigned> locInfo = SM.getDecomposedLoc(atLoc);
466 
467   // Try to load the file buffer.
468   bool invalidTemp = false;
469   StringRef file = SM.getBufferData(locInfo.first, &invalidTemp);
470   if (invalidTemp)
471     return false;
472 
473   const char *tokenBegin = file.data() + locInfo.second;
474 
475   // Lex from the start of the given location.
476   Lexer lexer(SM.getLocForStartOfFile(locInfo.first),
477               Pass.Ctx.getLangOpts(),
478               file.begin(), tokenBegin, file.end());
479   Token tok;
480   lexer.LexFromRawLexer(tok);
481   if (tok.isNot(tok::at)) return false;
482   lexer.LexFromRawLexer(tok);
483   if (tok.isNot(tok::raw_identifier)) return false;
484   if (StringRef(tok.getRawIdentifierData(), tok.getLength())
485         != "property")
486     return false;
487   lexer.LexFromRawLexer(tok);
488 
489   if (tok.isNot(tok::l_paren)) {
490     Pass.TA.insert(tok.getLocation(), std::string("(") + attr.str() + ") ");
491     return true;
492   }
493 
494   lexer.LexFromRawLexer(tok);
495   if (tok.is(tok::r_paren)) {
496     Pass.TA.insert(tok.getLocation(), attr);
497     return true;
498   }
499 
500   if (tok.isNot(tok::raw_identifier)) return false;
501 
502   Pass.TA.insert(tok.getLocation(), std::string(attr) + ", ");
503   return true;
504 }
505 
506 void MigrationContext::traverse(TranslationUnitDecl *TU) {
507   for (traverser_iterator
508          I = traversers_begin(), E = traversers_end(); I != E; ++I)
509     (*I)->traverseTU(*this);
510 
511   ASTTransform(*this).TraverseDecl(TU);
512 }
513 
514 static void GCRewriteFinalize(MigrationPass &pass) {
515   ASTContext &Ctx = pass.Ctx;
516   TransformActions &TA = pass.TA;
517   DeclContext *DC = Ctx.getTranslationUnitDecl();
518   Selector FinalizeSel =
519    Ctx.Selectors.getNullarySelector(&pass.Ctx.Idents.get("finalize"));
520 
521   typedef DeclContext::specific_decl_iterator<ObjCImplementationDecl>
522   impl_iterator;
523   for (impl_iterator I = impl_iterator(DC->decls_begin()),
524        E = impl_iterator(DC->decls_end()); I != E; ++I) {
525     for (ObjCImplementationDecl::instmeth_iterator
526          MI = I->instmeth_begin(),
527          ME = I->instmeth_end(); MI != ME; ++MI) {
528       ObjCMethodDecl *MD = *MI;
529       if (!MD->hasBody())
530         continue;
531 
532       if (MD->isInstanceMethod() && MD->getSelector() == FinalizeSel) {
533         ObjCMethodDecl *FinalizeM = MD;
534         Transaction Trans(TA);
535         TA.insert(FinalizeM->getSourceRange().getBegin(),
536                   "#if !__has_feature(objc_arc)\n");
537         CharSourceRange::getTokenRange(FinalizeM->getSourceRange());
538         const SourceManager &SM = pass.Ctx.getSourceManager();
539         const LangOptions &LangOpts = pass.Ctx.getLangOpts();
540         bool Invalid;
541         std::string str = "\n#endif\n";
542         str += Lexer::getSourceText(
543                   CharSourceRange::getTokenRange(FinalizeM->getSourceRange()),
544                                     SM, LangOpts, &Invalid);
545         TA.insertAfterToken(FinalizeM->getSourceRange().getEnd(), str);
546 
547         break;
548       }
549     }
550   }
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // getAllTransformations.
555 //===----------------------------------------------------------------------===//
556 
557 static void traverseAST(MigrationPass &pass) {
558   MigrationContext MigrateCtx(pass);
559 
560   if (pass.isGCMigration()) {
561     MigrateCtx.addTraverser(new GCCollectableCallsTraverser);
562     MigrateCtx.addTraverser(new GCAttrsTraverser());
563   }
564   MigrateCtx.addTraverser(new PropertyRewriteTraverser());
565   MigrateCtx.addTraverser(new BlockObjCVariableTraverser());
566 
567   MigrateCtx.traverse(pass.Ctx.getTranslationUnitDecl());
568 }
569 
570 static void independentTransforms(MigrationPass &pass) {
571   rewriteAutoreleasePool(pass);
572   removeRetainReleaseDeallocFinalize(pass);
573   rewriteUnusedInitDelegate(pass);
574   removeZeroOutPropsInDeallocFinalize(pass);
575   makeAssignARCSafe(pass);
576   rewriteUnbridgedCasts(pass);
577   checkAPIUses(pass);
578   traverseAST(pass);
579 }
580 
581 std::vector<TransformFn> arcmt::getAllTransformations(
582                                                LangOptions::GCMode OrigGCMode,
583                                                bool NoFinalizeRemoval) {
584   std::vector<TransformFn> transforms;
585 
586   if (OrigGCMode ==  LangOptions::GCOnly && NoFinalizeRemoval)
587     transforms.push_back(GCRewriteFinalize);
588   transforms.push_back(independentTransforms);
589   // This depends on previous transformations removing various expressions.
590   transforms.push_back(removeEmptyStatementsAndDeallocFinalize);
591 
592   return transforms;
593 }
594