1 //===--- RefactoringCallbacks.cpp - Structural query framework ------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // 11 //===----------------------------------------------------------------------===// 12 #include "clang/Tooling/RefactoringCallbacks.h" 13 #include "clang/ASTMatchers/ASTMatchFinder.h" 14 #include "clang/Basic/SourceLocation.h" 15 #include "clang/Lex/Lexer.h" 16 17 using llvm::StringError; 18 using llvm::make_error; 19 20 namespace clang { 21 namespace tooling { 22 23 RefactoringCallback::RefactoringCallback() {} 24 tooling::Replacements &RefactoringCallback::getReplacements() { 25 return Replace; 26 } 27 28 ASTMatchRefactorer::ASTMatchRefactorer( 29 std::map<std::string, Replacements> &FileToReplaces) 30 : FileToReplaces(FileToReplaces) {} 31 32 void ASTMatchRefactorer::addDynamicMatcher( 33 const ast_matchers::internal::DynTypedMatcher &Matcher, 34 RefactoringCallback *Callback) { 35 MatchFinder.addDynamicMatcher(Matcher, Callback); 36 Callbacks.push_back(Callback); 37 } 38 39 class RefactoringASTConsumer : public ASTConsumer { 40 public: 41 explicit RefactoringASTConsumer(ASTMatchRefactorer &Refactoring) 42 : Refactoring(Refactoring) {} 43 44 void HandleTranslationUnit(ASTContext &Context) override { 45 // The ASTMatchRefactorer is re-used between translation units. 46 // Clear the matchers so that each Replacement is only emitted once. 47 for (const auto &Callback : Refactoring.Callbacks) { 48 Callback->getReplacements().clear(); 49 } 50 Refactoring.MatchFinder.matchAST(Context); 51 for (const auto &Callback : Refactoring.Callbacks) { 52 for (const auto &Replacement : Callback->getReplacements()) { 53 llvm::Error Err = 54 Refactoring.FileToReplaces[Replacement.getFilePath()].add( 55 Replacement); 56 if (Err) { 57 llvm::errs() << "Skipping replacement " << Replacement.toString() 58 << " due to this error:\n" 59 << toString(std::move(Err)) << "\n"; 60 } 61 } 62 } 63 } 64 65 private: 66 ASTMatchRefactorer &Refactoring; 67 }; 68 69 std::unique_ptr<ASTConsumer> ASTMatchRefactorer::newASTConsumer() { 70 return llvm::make_unique<RefactoringASTConsumer>(*this); 71 } 72 73 static Replacement replaceStmtWithText(SourceManager &Sources, const Stmt &From, 74 StringRef Text) { 75 return tooling::Replacement( 76 Sources, CharSourceRange::getTokenRange(From.getSourceRange()), Text); 77 } 78 static Replacement replaceStmtWithStmt(SourceManager &Sources, const Stmt &From, 79 const Stmt &To) { 80 return replaceStmtWithText( 81 Sources, From, 82 Lexer::getSourceText(CharSourceRange::getTokenRange(To.getSourceRange()), 83 Sources, LangOptions())); 84 } 85 86 ReplaceStmtWithText::ReplaceStmtWithText(StringRef FromId, StringRef ToText) 87 : FromId(FromId), ToText(ToText) {} 88 89 void ReplaceStmtWithText::run( 90 const ast_matchers::MatchFinder::MatchResult &Result) { 91 if (const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId)) { 92 auto Err = Replace.add(tooling::Replacement( 93 *Result.SourceManager, 94 CharSourceRange::getTokenRange(FromMatch->getSourceRange()), ToText)); 95 // FIXME: better error handling. For now, just print error message in the 96 // release version. 97 if (Err) { 98 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 99 assert(false); 100 } 101 } 102 } 103 104 ReplaceStmtWithStmt::ReplaceStmtWithStmt(StringRef FromId, StringRef ToId) 105 : FromId(FromId), ToId(ToId) {} 106 107 void ReplaceStmtWithStmt::run( 108 const ast_matchers::MatchFinder::MatchResult &Result) { 109 const Stmt *FromMatch = Result.Nodes.getNodeAs<Stmt>(FromId); 110 const Stmt *ToMatch = Result.Nodes.getNodeAs<Stmt>(ToId); 111 if (FromMatch && ToMatch) { 112 auto Err = Replace.add( 113 replaceStmtWithStmt(*Result.SourceManager, *FromMatch, *ToMatch)); 114 // FIXME: better error handling. For now, just print error message in the 115 // release version. 116 if (Err) { 117 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 118 assert(false); 119 } 120 } 121 } 122 123 ReplaceIfStmtWithItsBody::ReplaceIfStmtWithItsBody(StringRef Id, 124 bool PickTrueBranch) 125 : Id(Id), PickTrueBranch(PickTrueBranch) {} 126 127 void ReplaceIfStmtWithItsBody::run( 128 const ast_matchers::MatchFinder::MatchResult &Result) { 129 if (const IfStmt *Node = Result.Nodes.getNodeAs<IfStmt>(Id)) { 130 const Stmt *Body = PickTrueBranch ? Node->getThen() : Node->getElse(); 131 if (Body) { 132 auto Err = 133 Replace.add(replaceStmtWithStmt(*Result.SourceManager, *Node, *Body)); 134 // FIXME: better error handling. For now, just print error message in the 135 // release version. 136 if (Err) { 137 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 138 assert(false); 139 } 140 } else if (!PickTrueBranch) { 141 // If we want to use the 'else'-branch, but it doesn't exist, delete 142 // the whole 'if'. 143 auto Err = 144 Replace.add(replaceStmtWithText(*Result.SourceManager, *Node, "")); 145 // FIXME: better error handling. For now, just print error message in the 146 // release version. 147 if (Err) { 148 llvm::errs() << llvm::toString(std::move(Err)) << "\n"; 149 assert(false); 150 } 151 } 152 } 153 } 154 155 ReplaceNodeWithTemplate::ReplaceNodeWithTemplate( 156 llvm::StringRef FromId, std::vector<TemplateElement> Template) 157 : FromId(FromId), Template(std::move(Template)) {} 158 159 llvm::Expected<std::unique_ptr<ReplaceNodeWithTemplate>> 160 ReplaceNodeWithTemplate::create(StringRef FromId, StringRef ToTemplate) { 161 std::vector<TemplateElement> ParsedTemplate; 162 for (size_t Index = 0; Index < ToTemplate.size();) { 163 if (ToTemplate[Index] == '$') { 164 if (ToTemplate.substr(Index, 2) == "$$") { 165 Index += 2; 166 ParsedTemplate.push_back( 167 TemplateElement{TemplateElement::Literal, "$"}); 168 } else if (ToTemplate.substr(Index, 2) == "${") { 169 size_t EndOfIdentifier = ToTemplate.find("}", Index); 170 if (EndOfIdentifier == std::string::npos) { 171 return make_error<StringError>( 172 "Unterminated ${...} in replacement template near " + 173 ToTemplate.substr(Index), 174 llvm::inconvertibleErrorCode()); 175 } 176 std::string SourceNodeName = 177 ToTemplate.substr(Index + 2, EndOfIdentifier - Index - 2); 178 ParsedTemplate.push_back( 179 TemplateElement{TemplateElement::Identifier, SourceNodeName}); 180 Index = EndOfIdentifier + 1; 181 } else { 182 return make_error<StringError>( 183 "Invalid $ in replacement template near " + 184 ToTemplate.substr(Index), 185 llvm::inconvertibleErrorCode()); 186 } 187 } else { 188 size_t NextIndex = ToTemplate.find('$', Index + 1); 189 ParsedTemplate.push_back( 190 TemplateElement{TemplateElement::Literal, 191 ToTemplate.substr(Index, NextIndex - Index)}); 192 Index = NextIndex; 193 } 194 } 195 return std::unique_ptr<ReplaceNodeWithTemplate>( 196 new ReplaceNodeWithTemplate(FromId, std::move(ParsedTemplate))); 197 } 198 199 void ReplaceNodeWithTemplate::run( 200 const ast_matchers::MatchFinder::MatchResult &Result) { 201 const auto &NodeMap = Result.Nodes.getMap(); 202 203 std::string ToText; 204 for (const auto &Element : Template) { 205 switch (Element.Type) { 206 case TemplateElement::Literal: 207 ToText += Element.Value; 208 break; 209 case TemplateElement::Identifier: { 210 auto NodeIter = NodeMap.find(Element.Value); 211 if (NodeIter == NodeMap.end()) { 212 llvm::errs() << "Node " << Element.Value 213 << " used in replacement template not bound in Matcher \n"; 214 llvm::report_fatal_error("Unbound node in replacement template."); 215 } 216 CharSourceRange Source = 217 CharSourceRange::getTokenRange(NodeIter->second.getSourceRange()); 218 ToText += Lexer::getSourceText(Source, *Result.SourceManager, 219 Result.Context->getLangOpts()); 220 break; 221 } 222 } 223 } 224 if (NodeMap.count(FromId) == 0) { 225 llvm::errs() << "Node to be replaced " << FromId 226 << " not bound in query.\n"; 227 llvm::report_fatal_error("FromId node not bound in MatchResult"); 228 } 229 auto Replacement = 230 tooling::Replacement(*Result.SourceManager, &NodeMap.at(FromId), ToText, 231 Result.Context->getLangOpts()); 232 llvm::Error Err = Replace.add(Replacement); 233 if (Err) { 234 llvm::errs() << "Query and replace failed in " << Replacement.getFilePath() 235 << "! " << llvm::toString(std::move(Err)) << "\n"; 236 llvm::report_fatal_error("Replacement failed"); 237 } 238 } 239 240 } // end namespace tooling 241 } // end namespace clang 242