1 //===--- TransProtectedScope.cpp - Transformations to ARC mode ------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// Adds braces around case statements that "contain" the initialization of a
// retainable variable, which would otherwise trigger the "switch case is in
// protected scope" error.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "Transforms.h"
16 #include "Internals.h"
17 #include "clang/Sema/SemaDiagnostic.h"
18 
19 using namespace clang;
20 using namespace arcmt;
21 using namespace trans;
22 
23 namespace {
24 
25 struct CaseInfo {
26   SwitchCase *SC;
27   SourceRange Range;
28   bool FixedBypass;
29 
30   CaseInfo() : SC(0), FixedBypass(false) {}
31   CaseInfo(SwitchCase *S, SourceRange Range)
32     : SC(S), Range(Range), FixedBypass(false) {}
33 };
34 
35 class CaseCollector : public RecursiveASTVisitor<CaseCollector> {
36   llvm::SmallVectorImpl<CaseInfo> &Cases;
37 
38 public:
39   CaseCollector(llvm::SmallVectorImpl<CaseInfo> &Cases)
40     : Cases(Cases) { }
41 
42   bool VisitSwitchStmt(SwitchStmt *S) {
43     SourceLocation NextLoc = S->getLocEnd();
44     SwitchCase *Curr = S->getSwitchCaseList();
45     // We iterate over case statements in reverse source-order.
46     while (Curr) {
47       Cases.push_back(CaseInfo(Curr,SourceRange(Curr->getLocStart(), NextLoc)));
48       NextLoc = Curr->getLocStart();
49       Curr = Curr->getNextSwitchCase();
50     }
51     return true;
52   }
53 };
54 
55 } // anonymous namespace
56 
57 static bool isInRange(FullSourceLoc Loc, SourceRange R) {
58   return !Loc.isBeforeInTranslationUnitThan(R.getBegin()) &&
59           Loc.isBeforeInTranslationUnitThan(R.getEnd());
60 }
61 
62 static bool handleProtectedNote(const StoredDiagnostic &Diag,
63                                 llvm::SmallVectorImpl<CaseInfo> &Cases,
64                                 TransformActions &TA) {
65   assert(Diag.getLevel() == DiagnosticsEngine::Note);
66 
67   for (unsigned i = 0; i != Cases.size(); i++) {
68     CaseInfo &info = Cases[i];
69     if (isInRange(Diag.getLocation(), info.Range)) {
70       TA.clearDiagnostic(Diag.getID(), Diag.getLocation());
71       if (!info.FixedBypass) {
72         TA.insertAfterToken(info.SC->getColonLoc(), " {");
73         TA.insert(info.Range.getEnd(), "}\n");
74         info.FixedBypass = true;
75       }
76       return true;
77     }
78   }
79 
80   return false;
81 }
82 
83 static void handleProtectedScopeError(CapturedDiagList::iterator &DiagI,
84                                       CapturedDiagList::iterator DiagE,
85                                       llvm::SmallVectorImpl<CaseInfo> &Cases,
86                                       TransformActions &TA) {
87   Transaction Trans(TA);
88   assert(DiagI->getID() == diag::err_switch_into_protected_scope);
89   SourceLocation ErrLoc = DiagI->getLocation();
90   bool handledAllNotes = true;
91   ++DiagI;
92   for (; DiagI != DiagE && DiagI->getLevel() == DiagnosticsEngine::Note;
93        ++DiagI) {
94     if (!handleProtectedNote(*DiagI, Cases, TA))
95       handledAllNotes = false;
96   }
97 
98   if (handledAllNotes)
99     TA.clearDiagnostic(diag::err_switch_into_protected_scope, ErrLoc);
100 }
101 
102 void ProtectedScopeTraverser::traverseBody(BodyContext &BodyCtx) {
103   MigrationPass &Pass = BodyCtx.getMigrationContext().Pass;
104   SmallVector<CaseInfo, 16> Cases;
105   CaseCollector(Cases).TraverseStmt(BodyCtx.getTopStmt());
106 
107   SourceRange BodyRange = BodyCtx.getTopStmt()->getSourceRange();
108   const CapturedDiagList &DiagList = Pass.getDiags();
109   CapturedDiagList::iterator I = DiagList.begin(), E = DiagList.end();
110   while (I != E) {
111     if (I->getID() == diag::err_switch_into_protected_scope &&
112         isInRange(I->getLocation(), BodyRange)) {
113       handleProtectedScopeError(I, E, Cases, Pass.TA);
114       continue;
115     }
116     ++I;
117   }
118 }
119