1 //===- unittest/Tooling/RefactoringTestActionRulesTest.cpp ----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ReplacementTest.h"
10 #include "RewriterTestContext.h"
11 #include "clang/Tooling/Refactoring.h"
12 #include "clang/Tooling/Refactoring/Extract/Extract.h"
13 #include "clang/Tooling/Refactoring/RefactoringAction.h"
14 #include "clang/Tooling/Refactoring/RefactoringDiagnostic.h"
15 #include "clang/Tooling/Refactoring/Rename/SymbolName.h"
16 #include "clang/Tooling/Tooling.h"
17 #include "llvm/Support/Errc.h"
18 #include "gtest/gtest.h"
19 
20 using namespace clang;
21 using namespace tooling;
22 
23 namespace {
24 
25 class RefactoringActionRulesTest : public ::testing::Test {
26 protected:
SetUp()27   void SetUp() override {
28     Context.Sources.setMainFileID(
29         Context.createInMemoryFile("input.cpp", DefaultCode));
30   }
31 
32   RewriterTestContext Context;
33   std::string DefaultCode = std::string(100, 'a');
34 };
35 
36 Expected<AtomicChanges>
createReplacements(const std::unique_ptr<RefactoringActionRule> & Rule,RefactoringRuleContext & Context)37 createReplacements(const std::unique_ptr<RefactoringActionRule> &Rule,
38                    RefactoringRuleContext &Context) {
39   class Consumer final : public RefactoringResultConsumer {
40     void handleError(llvm::Error Err) override { Result = std::move(Err); }
41 
42     void handle(AtomicChanges SourceReplacements) override {
43       Result = std::move(SourceReplacements);
44     }
45     void handle(SymbolOccurrences Occurrences) override {
46       RefactoringResultConsumer::handle(std::move(Occurrences));
47     }
48 
49   public:
50     Optional<Expected<AtomicChanges>> Result;
51   };
52 
53   Consumer C;
54   Rule->invoke(C, Context);
55   return std::move(*C.Result);
56 }
57 
TEST_F(RefactoringActionRulesTest,MyFirstRefactoringRule)58 TEST_F(RefactoringActionRulesTest, MyFirstRefactoringRule) {
59   class ReplaceAWithB : public SourceChangeRefactoringRule {
60     std::pair<SourceRange, int> Selection;
61 
62   public:
63     ReplaceAWithB(std::pair<SourceRange, int> Selection)
64         : Selection(Selection) {}
65 
66     static Expected<ReplaceAWithB>
67     initiate(RefactoringRuleContext &Cotnext,
68              std::pair<SourceRange, int> Selection) {
69       return ReplaceAWithB(Selection);
70     }
71 
72     Expected<AtomicChanges>
73     createSourceReplacements(RefactoringRuleContext &Context) {
74       const SourceManager &SM = Context.getSources();
75       SourceLocation Loc =
76           Selection.first.getBegin().getLocWithOffset(Selection.second);
77       AtomicChange Change(SM, Loc);
78       llvm::Error E = Change.replace(SM, Loc, 1, "b");
79       if (E)
80         return std::move(E);
81       return AtomicChanges{Change};
82     }
83   };
84 
85   class SelectionRequirement : public SourceRangeSelectionRequirement {
86   public:
87     Expected<std::pair<SourceRange, int>>
88     evaluate(RefactoringRuleContext &Context) const {
89       Expected<SourceRange> R =
90           SourceRangeSelectionRequirement::evaluate(Context);
91       if (!R)
92         return R.takeError();
93       return std::make_pair(*R, 20);
94     }
95   };
96   auto Rule =
97       createRefactoringActionRule<ReplaceAWithB>(SelectionRequirement());
98 
99   // When the requirements are satisfied, the rule's function must be invoked.
100   {
101     RefactoringRuleContext RefContext(Context.Sources);
102     SourceLocation Cursor =
103         Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID())
104             .getLocWithOffset(10);
105     RefContext.setSelectionRange({Cursor, Cursor});
106 
107     Expected<AtomicChanges> ErrorOrResult =
108         createReplacements(Rule, RefContext);
109     ASSERT_FALSE(!ErrorOrResult);
110     AtomicChanges Result = std::move(*ErrorOrResult);
111     ASSERT_EQ(Result.size(), 1u);
112     std::string YAMLString =
113         const_cast<AtomicChange &>(Result[0]).toYAMLString();
114 
115     ASSERT_STREQ("---\n"
116                  "Key:             'input.cpp:30'\n"
117                  "FilePath:        input.cpp\n"
118                  "Error:           ''\n"
119                  "InsertedHeaders: []\n"
120                  "RemovedHeaders:  []\n"
121                  "Replacements:\n"
122                  "  - FilePath:        input.cpp\n"
123                  "    Offset:          30\n"
124                  "    Length:          1\n"
125                  "    ReplacementText: b\n"
126                  "...\n",
127                  YAMLString.c_str());
128   }
129 
130   // When one of the requirements is not satisfied, invoke should return a
131   // valid error.
132   {
133     RefactoringRuleContext RefContext(Context.Sources);
134     Expected<AtomicChanges> ErrorOrResult =
135         createReplacements(Rule, RefContext);
136 
137     ASSERT_TRUE(!ErrorOrResult);
138     unsigned DiagID;
139     llvm::handleAllErrors(ErrorOrResult.takeError(),
140                           [&](DiagnosticError &Error) {
141                             DiagID = Error.getDiagnostic().second.getDiagID();
142                           });
143     EXPECT_EQ(DiagID, diag::err_refactor_no_selection);
144   }
145 }
146 
TEST_F(RefactoringActionRulesTest, ReturnError) {
  // A rule whose source-change step always fails with a StringError whose
  // message is "Error"; the test checks that this error reaches the consumer.
  class ErrorRule : public SourceChangeRefactoringRule {
  public:
    static Expected<ErrorRule> initiate(RefactoringRuleContext &,
                                        SourceRange R) {
      return ErrorRule(R);
    }

    // The selection is not needed by this rule, so the parameter is unnamed.
    explicit ErrorRule(SourceRange) {}
    Expected<AtomicChanges>
    createSourceReplacements(RefactoringRuleContext &) override {
      return llvm::make_error<llvm::StringError>(
          "Error", llvm::make_error_code(llvm::errc::invalid_argument));
    }
  };

  auto Rule =
      createRefactoringActionRule<ErrorRule>(SourceRangeSelectionRequirement());
  RefactoringRuleContext RefContext(Context.Sources);
  SourceLocation Cursor =
      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
  RefContext.setSelectionRange({Cursor, Cursor});
  Expected<AtomicChanges> Result = createReplacements(Rule, RefContext);

  ASSERT_TRUE(!Result);
  std::string Message;
  llvm::handleAllErrors(Result.takeError(), [&](llvm::StringError &Error) {
    Message = Error.getMessage();
  });
  EXPECT_EQ(Message, "Error");
}
177 
findOccurrences(RefactoringActionRule & Rule,RefactoringRuleContext & Context)178 Optional<SymbolOccurrences> findOccurrences(RefactoringActionRule &Rule,
179                                             RefactoringRuleContext &Context) {
180   class Consumer final : public RefactoringResultConsumer {
181     void handleError(llvm::Error) override {}
182     void handle(SymbolOccurrences Occurrences) override {
183       Result = std::move(Occurrences);
184     }
185     void handle(AtomicChanges Changes) override {
186       RefactoringResultConsumer::handle(std::move(Changes));
187     }
188 
189   public:
190     Optional<SymbolOccurrences> Result;
191   };
192 
193   Consumer C;
194   Rule.invoke(C, Context);
195   return std::move(C.Result);
196 }
197 
TEST_F(RefactoringActionRulesTest, ReturnSymbolOccurrences) {
  // A rule that reports a single matching occurrence of the symbol "test" at
  // the beginning of the selected range.
  class FindOccurrences : public FindSymbolOccurrencesRefactoringRule {
    SourceRange Selection;

  public:
    explicit FindOccurrences(SourceRange Selection) : Selection(Selection) {}

    static Expected<FindOccurrences> initiate(RefactoringRuleContext &,
                                              SourceRange Selection) {
      return FindOccurrences(Selection);
    }

    Expected<SymbolOccurrences>
    findSymbolOccurrences(RefactoringRuleContext &) override {
      SymbolOccurrences Occurrences;
      Occurrences.push_back(SymbolOccurrence(SymbolName("test"),
                                             SymbolOccurrence::MatchingSymbol,
                                             Selection.getBegin()));
      return std::move(Occurrences);
    }
  };

  auto Rule = createRefactoringActionRule<FindOccurrences>(
      SourceRangeSelectionRequirement());

  RefactoringRuleContext RefContext(Context.Sources);
  SourceLocation Cursor =
      Context.Sources.getLocForStartOfFile(Context.Sources.getMainFileID());
  RefContext.setSelectionRange({Cursor, Cursor});
  Optional<SymbolOccurrences> Result = findOccurrences(*Rule, RefContext);

  ASSERT_FALSE(!Result);
  SymbolOccurrences Occurrences = std::move(*Result);
  // Fatal assertions: the checks below index element 0, which would be an
  // out-of-bounds access if the sizes were wrong.
  ASSERT_EQ(Occurrences.size(), 1u);
  EXPECT_EQ(Occurrences[0].getKind(), SymbolOccurrence::MatchingSymbol);
  ASSERT_EQ(Occurrences[0].getNameRanges().size(), 1u);
  EXPECT_EQ(Occurrences[0].getNameRanges()[0],
            SourceRange(Cursor, Cursor.getLocWithOffset(strlen("test"))));
}
237 
TEST_F(RefactoringActionRulesTest, EditorCommandBinding) {
  // Checks the name, title and description exposed by the extract-function
  // action's descriptor.
  const RefactoringDescriptor &Desc = ExtractFunction::describe();
  EXPECT_EQ(Desc.Name, "extract-function");
  EXPECT_EQ(Desc.Title, "Extract Function");
  EXPECT_EQ(Desc.Description,
            "(WIP action; use with caution!) Extracts code into a new function");
}
246 
247 } // end anonymous namespace
248