1 //===- TreeTest.cpp -------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "clang/Tooling/Syntax/Tree.h"
10 #include "clang/AST/ASTConsumer.h"
11 #include "clang/AST/Decl.h"
12 #include "clang/AST/Stmt.h"
13 #include "clang/Basic/LLVM.h"
14 #include "clang/Frontend/CompilerInstance.h"
15 #include "clang/Frontend/CompilerInvocation.h"
16 #include "clang/Frontend/FrontendAction.h"
17 #include "clang/Lex/PreprocessorOptions.h"
18 #include "clang/Tooling/Core/Replacement.h"
19 #include "clang/Tooling/Syntax/BuildTree.h"
20 #include "clang/Tooling/Syntax/Mutations.h"
21 #include "clang/Tooling/Syntax/Nodes.h"
22 #include "clang/Tooling/Syntax/Tokens.h"
23 #include "clang/Tooling/Tooling.h"
24 #include "llvm/ADT/ArrayRef.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/StringRef.h"
27 #include "llvm/Support/Casting.h"
28 #include "llvm/Support/Error.h"
29 #include "llvm/Testing/Support/Annotations.h"
30 #include "gmock/gmock.h"
31 #include "gtest/gtest.h"
32 #include <cstdlib>
33 
34 using namespace clang;
35 
36 namespace {
37 static llvm::ArrayRef<syntax::Token> tokens(syntax::Node *N) {
38   assert(N->isOriginal() && "tokens of modified nodes are not well-defined");
39   if (auto *L = dyn_cast<syntax::Leaf>(N))
40     return llvm::makeArrayRef(L->token(), 1);
41   auto *T = cast<syntax::Tree>(N);
42   return llvm::makeArrayRef(T->firstLeaf()->token(),
43                             T->lastLeaf()->token() + 1);
44 }
45 
46 class SyntaxTreeTest : public ::testing::Test {
47 protected:
48   // Build a syntax tree for the code.
49   syntax::TranslationUnit *buildTree(llvm::StringRef Code) {
50     // FIXME: this code is almost the identical to the one in TokensTest. Share
51     //        it.
52     class BuildSyntaxTree : public ASTConsumer {
53     public:
54       BuildSyntaxTree(syntax::TranslationUnit *&Root,
55                       std::unique_ptr<syntax::Arena> &Arena,
56                       std::unique_ptr<syntax::TokenCollector> Tokens)
57           : Root(Root), Arena(Arena), Tokens(std::move(Tokens)) {
58         assert(this->Tokens);
59       }
60 
61       void HandleTranslationUnit(ASTContext &Ctx) override {
62         Arena = std::make_unique<syntax::Arena>(Ctx.getSourceManager(),
63                                                 Ctx.getLangOpts(),
64                                                 std::move(*Tokens).consume());
65         Tokens = nullptr; // make sure we fail if this gets called twice.
66         Root = syntax::buildSyntaxTree(*Arena, *Ctx.getTranslationUnitDecl());
67       }
68 
69     private:
70       syntax::TranslationUnit *&Root;
71       std::unique_ptr<syntax::Arena> &Arena;
72       std::unique_ptr<syntax::TokenCollector> Tokens;
73     };
74 
75     class BuildSyntaxTreeAction : public ASTFrontendAction {
76     public:
77       BuildSyntaxTreeAction(syntax::TranslationUnit *&Root,
78                             std::unique_ptr<syntax::Arena> &Arena)
79           : Root(Root), Arena(Arena) {}
80 
81       std::unique_ptr<ASTConsumer>
82       CreateASTConsumer(CompilerInstance &CI, StringRef InFile) override {
83         // We start recording the tokens, ast consumer will take on the result.
84         auto Tokens =
85             std::make_unique<syntax::TokenCollector>(CI.getPreprocessor());
86         return std::make_unique<BuildSyntaxTree>(Root, Arena,
87                                                  std::move(Tokens));
88       }
89 
90     private:
91       syntax::TranslationUnit *&Root;
92       std::unique_ptr<syntax::Arena> &Arena;
93     };
94 
95     constexpr const char *FileName = "./input.cpp";
96     FS->addFile(FileName, time_t(), llvm::MemoryBuffer::getMemBufferCopy(""));
97     if (!Diags->getClient())
98       Diags->setClient(new IgnoringDiagConsumer);
99     // Prepare to run a compiler.
100     std::vector<const char *> Args = {"syntax-test", "-std=c++11",
101                                       "-fsyntax-only", FileName};
102     Invocation = createInvocationFromCommandLine(Args, Diags, FS);
103     assert(Invocation);
104     Invocation->getFrontendOpts().DisableFree = false;
105     Invocation->getPreprocessorOpts().addRemappedFile(
106         FileName, llvm::MemoryBuffer::getMemBufferCopy(Code).release());
107     CompilerInstance Compiler;
108     Compiler.setInvocation(Invocation);
109     Compiler.setDiagnostics(Diags.get());
110     Compiler.setFileManager(FileMgr.get());
111     Compiler.setSourceManager(SourceMgr.get());
112 
113     syntax::TranslationUnit *Root = nullptr;
114     BuildSyntaxTreeAction Recorder(Root, this->Arena);
115     if (!Compiler.ExecuteAction(Recorder)) {
116       ADD_FAILURE() << "failed to run the frontend";
117       std::abort();
118     }
119     return Root;
120   }
121 
122   // Adds a file to the test VFS.
123   void addFile(llvm::StringRef Path, llvm::StringRef Contents) {
124     if (!FS->addFile(Path, time_t(),
125                      llvm::MemoryBuffer::getMemBufferCopy(Contents))) {
126       ADD_FAILURE() << "could not add a file to VFS: " << Path;
127     }
128   }
129 
130   /// Finds the deepest node in the tree that covers exactly \p R.
131   /// FIXME: implement this efficiently and move to public syntax tree API.
132   syntax::Node *nodeByRange(llvm::Annotations::Range R, syntax::Node *Root) {
133     llvm::ArrayRef<syntax::Token> Toks = tokens(Root);
134 
135     if (Toks.front().location().isFileID() &&
136         Toks.back().location().isFileID() &&
137         syntax::Token::range(*SourceMgr, Toks.front(), Toks.back()) ==
138             syntax::FileRange(SourceMgr->getMainFileID(), R.Begin, R.End))
139       return Root;
140 
141     auto *T = dyn_cast<syntax::Tree>(Root);
142     if (!T)
143       return nullptr;
144     for (auto *C = T->firstChild(); C != nullptr; C = C->nextSibling()) {
145       if (auto *Result = nodeByRange(R, C))
146         return Result;
147     }
148     return nullptr;
149   }
150 
151   // Data fields.
152   llvm::IntrusiveRefCntPtr<DiagnosticsEngine> Diags =
153       new DiagnosticsEngine(new DiagnosticIDs, new DiagnosticOptions);
154   IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> FS =
155       new llvm::vfs::InMemoryFileSystem;
156   llvm::IntrusiveRefCntPtr<FileManager> FileMgr =
157       new FileManager(FileSystemOptions(), FS);
158   llvm::IntrusiveRefCntPtr<SourceManager> SourceMgr =
159       new SourceManager(*Diags, *FileMgr);
160   std::shared_ptr<CompilerInvocation> Invocation;
161   // Set after calling buildTree().
162   std::unique_ptr<syntax::Arena> Arena;
163 };
164 
165 TEST_F(SyntaxTreeTest, Basic) {
166   std::pair</*Input*/ std::string, /*Expected*/ std::string> Cases[] = {
167       {
168           R"cpp(
169 int main() {}
170 void foo() {}
171     )cpp",
172           R"txt(
173 *: TranslationUnit
174 |-SimpleDeclaration
175 | |-int
176 | |-main
177 | |-(
178 | |-)
179 | `-CompoundStatement
180 |   |-{
181 |   `-}
182 `-SimpleDeclaration
183   |-void
184   |-foo
185   |-(
186   |-)
187   `-CompoundStatement
188     |-{
189     `-}
190 )txt"},
191       // if.
192       {
193           R"cpp(
194 int main() {
195   if (true) {}
196   if (true) {} else if (false) {}
197 }
198         )cpp",
199           R"txt(
200 *: TranslationUnit
201 `-SimpleDeclaration
202   |-int
203   |-main
204   |-(
205   |-)
206   `-CompoundStatement
207     |-{
208     |-IfStatement
209     | |-if
210     | |-(
211     | |-UnknownExpression
212     | | `-true
213     | |-)
214     | `-CompoundStatement
215     |   |-{
216     |   `-}
217     |-IfStatement
218     | |-if
219     | |-(
220     | |-UnknownExpression
221     | | `-true
222     | |-)
223     | |-CompoundStatement
224     | | |-{
225     | | `-}
226     | |-else
227     | `-IfStatement
228     |   |-if
229     |   |-(
230     |   |-UnknownExpression
231     |   | `-false
232     |   |-)
233     |   `-CompoundStatement
234     |     |-{
235     |     `-}
236     `-}
237         )txt"},
238       // for.
239       {R"cpp(
240 void test() {
241   for (;;)  {}
242 }
243 )cpp",
244        R"txt(
245 *: TranslationUnit
246 `-SimpleDeclaration
247   |-void
248   |-test
249   |-(
250   |-)
251   `-CompoundStatement
252     |-{
253     |-ForStatement
254     | |-for
255     | |-(
256     | |-;
257     | |-;
258     | |-)
259     | `-CompoundStatement
260     |   |-{
261     |   `-}
262     `-}
263         )txt"},
264       // declaration statement.
265       {"void test() { int a = 10; }",
266        R"txt(
267 *: TranslationUnit
268 `-SimpleDeclaration
269   |-void
270   |-test
271   |-(
272   |-)
273   `-CompoundStatement
274     |-{
275     |-DeclarationStatement
276     | |-SimpleDeclaration
277     | | |-int
278     | | |-a
279     | | |-=
280     | | `-UnknownExpression
281     | |   `-10
282     | `-;
283     `-}
284 )txt"},
285       {"void test() { ; }", R"txt(
286 *: TranslationUnit
287 `-SimpleDeclaration
288   |-void
289   |-test
290   |-(
291   |-)
292   `-CompoundStatement
293     |-{
294     |-EmptyStatement
295     | `-;
296     `-}
297 )txt"},
298       // switch, case and default.
299       {R"cpp(
300 void test() {
301   switch (true) {
302     case 0:
303     default:;
304   }
305 }
306 )cpp",
307        R"txt(
308 *: TranslationUnit
309 `-SimpleDeclaration
310   |-void
311   |-test
312   |-(
313   |-)
314   `-CompoundStatement
315     |-{
316     |-SwitchStatement
317     | |-switch
318     | |-(
319     | |-UnknownExpression
320     | | `-true
321     | |-)
322     | `-CompoundStatement
323     |   |-{
324     |   |-CaseStatement
325     |   | |-case
326     |   | |-UnknownExpression
327     |   | | `-0
328     |   | |-:
329     |   | `-DefaultStatement
330     |   |   |-default
331     |   |   |-:
332     |   |   `-EmptyStatement
333     |   |     `-;
334     |   `-}
335     `-}
336 )txt"},
337       // while.
338       {R"cpp(
339 void test() {
340   while (true) { continue; break; }
341 }
342 )cpp",
343        R"txt(
344 *: TranslationUnit
345 `-SimpleDeclaration
346   |-void
347   |-test
348   |-(
349   |-)
350   `-CompoundStatement
351     |-{
352     |-WhileStatement
353     | |-while
354     | |-(
355     | |-UnknownExpression
356     | | `-true
357     | |-)
358     | `-CompoundStatement
359     |   |-{
360     |   |-ContinueStatement
361     |   | |-continue
362     |   | `-;
363     |   |-BreakStatement
364     |   | |-break
365     |   | `-;
366     |   `-}
367     `-}
368 )txt"},
369       // return.
370       {R"cpp(
371 int test() { return 1; }
372       )cpp",
373        R"txt(
374 *: TranslationUnit
375 `-SimpleDeclaration
376   |-int
377   |-test
378   |-(
379   |-)
380   `-CompoundStatement
381     |-{
382     |-ReturnStatement
383     | |-return
384     | |-UnknownExpression
385     | | `-1
386     | `-;
387     `-}
388 )txt"},
389       // Range-based for.
390       {R"cpp(
391 void test() {
392   int a[3];
393   for (int x : a) ;
394 }
395       )cpp",
396        R"txt(
397 *: TranslationUnit
398 `-SimpleDeclaration
399   |-void
400   |-test
401   |-(
402   |-)
403   `-CompoundStatement
404     |-{
405     |-DeclarationStatement
406     | |-SimpleDeclaration
407     | | |-int
408     | | |-a
409     | | |-[
410     | | |-UnknownExpression
411     | | | `-3
412     | | `-]
413     | `-;
414     |-RangeBasedForStatement
415     | |-for
416     | |-(
417     | |-SimpleDeclaration
418     | | |-int
419     | | |-x
420     | | `-:
421     | |-UnknownExpression
422     | | `-a
423     | |-)
424     | `-EmptyStatement
425     |   `-;
426     `-}
427        )txt"},
428       // Unhandled statements should end up as 'unknown statement'.
429       // This example uses a 'label statement', which does not yet have a syntax
430       // counterpart.
431       {"void main() { foo: return 100; }", R"txt(
432 *: TranslationUnit
433 `-SimpleDeclaration
434   |-void
435   |-main
436   |-(
437   |-)
438   `-CompoundStatement
439     |-{
440     |-UnknownStatement
441     | |-foo
442     | |-:
443     | `-ReturnStatement
444     |   |-return
445     |   |-UnknownExpression
446     |   | `-100
447     |   `-;
448     `-}
449 )txt"},
450       // expressions should be wrapped in 'ExpressionStatement' when they appear
451       // in a statement position.
452       {R"cpp(
453 void test() {
454   test();
455   if (true) test(); else test();
456 }
457     )cpp",
458        R"txt(
459 *: TranslationUnit
460 `-SimpleDeclaration
461   |-void
462   |-test
463   |-(
464   |-)
465   `-CompoundStatement
466     |-{
467     |-ExpressionStatement
468     | |-UnknownExpression
469     | | |-test
470     | | |-(
471     | | `-)
472     | `-;
473     |-IfStatement
474     | |-if
475     | |-(
476     | |-UnknownExpression
477     | | `-true
478     | |-)
479     | |-ExpressionStatement
480     | | |-UnknownExpression
481     | | | |-test
482     | | | |-(
483     | | | `-)
484     | | `-;
485     | |-else
486     | `-ExpressionStatement
487     |   |-UnknownExpression
488     |   | |-test
489     |   | |-(
490     |   | `-)
491     |   `-;
492     `-}
493 )txt"},
494       // Multiple declarators group into a single SimpleDeclaration.
495       {R"cpp(
496       int *a, b;
497   )cpp",
498        R"txt(
499 *: TranslationUnit
500 `-SimpleDeclaration
501   |-int
502   |-*
503   |-a
504   |-,
505   |-b
506   `-;
507   )txt"},
508       {R"cpp(
509     typedef int *a, b;
510   )cpp",
511        R"txt(
512 *: TranslationUnit
513 `-SimpleDeclaration
514   |-typedef
515   |-int
516   |-*
517   |-a
518   |-,
519   |-b
520   `-;
521   )txt"},
522       // Multiple declarators inside a statement.
523       {R"cpp(
524 void foo() {
525       int *a, b;
526       typedef int *ta, tb;
527 }
528   )cpp",
529        R"txt(
530 *: TranslationUnit
531 `-SimpleDeclaration
532   |-void
533   |-foo
534   |-(
535   |-)
536   `-CompoundStatement
537     |-{
538     |-DeclarationStatement
539     | |-SimpleDeclaration
540     | | |-int
541     | | |-*
542     | | |-a
543     | | |-,
544     | | `-b
545     | `-;
546     |-DeclarationStatement
547     | |-SimpleDeclaration
548     | | |-typedef
549     | | |-int
550     | | |-*
551     | | |-ta
552     | | |-,
553     | | `-tb
554     | `-;
555     `-}
556   )txt"},
557       {R"cpp(
558 namespace a { namespace b {} }
559 namespace a::b {}
560 namespace {}
561 
562 namespace foo = a;
563     )cpp",
564        R"txt(
565 *: TranslationUnit
566 |-NamespaceDefinition
567 | |-namespace
568 | |-a
569 | |-{
570 | |-NamespaceDefinition
571 | | |-namespace
572 | | |-b
573 | | |-{
574 | | `-}
575 | `-}
576 |-NamespaceDefinition
577 | |-namespace
578 | |-a
579 | |-::
580 | |-b
581 | |-{
582 | `-}
583 |-NamespaceDefinition
584 | |-namespace
585 | |-{
586 | `-}
587 `-NamespaceAliasDefinition
588   |-namespace
589   |-foo
590   |-=
591   |-a
592   `-;
593 )txt"},
594       {R"cpp(
595 namespace ns {}
596 using namespace ::ns;
597     )cpp",
598        R"txt(
599 *: TranslationUnit
600 |-NamespaceDefinition
601 | |-namespace
602 | |-ns
603 | |-{
604 | `-}
605 `-UsingNamespaceDirective
606   |-using
607   |-namespace
608   |-::
609   |-ns
610   `-;
611        )txt"},
612       {R"cpp(
613 namespace ns { int a; }
614 using ns::a;
615     )cpp",
616        R"txt(
617 *: TranslationUnit
618 |-NamespaceDefinition
619 | |-namespace
620 | |-ns
621 | |-{
622 | |-SimpleDeclaration
623 | | |-int
624 | | |-a
625 | | `-;
626 | `-}
627 `-UsingDeclaration
628   |-using
629   |-ns
630   |-::
631   |-a
632   `-;
633        )txt"},
634       {R"cpp(
635 template <class T> struct X {
636   using T::foo;
637   using typename T::bar;
638 };
639     )cpp",
640        R"txt(
641 *: TranslationUnit
642 `-UnknownDeclaration
643   |-template
644   |-<
645   |-UnknownDeclaration
646   | |-class
647   | `-T
648   |->
649   |-struct
650   |-X
651   |-{
652   |-UsingDeclaration
653   | |-using
654   | |-T
655   | |-::
656   | |-foo
657   | `-;
658   |-UsingDeclaration
659   | |-using
660   | |-typename
661   | |-T
662   | |-::
663   | |-bar
664   | `-;
665   |-}
666   `-;
667        )txt"},
668       {R"cpp(
669 using type = int;
670     )cpp",
671        R"txt(
672 *: TranslationUnit
673 `-TypeAliasDeclaration
674   |-using
675   |-type
676   |-=
677   |-int
678   `-;
679        )txt"},
680       {R"cpp(
681 ;
682     )cpp",
683        R"txt(
684 *: TranslationUnit
685 `-EmptyDeclaration
686   `-;
687        )txt"},
688       {R"cpp(
689 static_assert(true, "message");
690 static_assert(true);
691     )cpp",
692        R"txt(
693 *: TranslationUnit
694 |-StaticAssertDeclaration
695 | |-static_assert
696 | |-(
697 | |-UnknownExpression
698 | | `-true
699 | |-,
700 | |-UnknownExpression
701 | | `-"message"
702 | |-)
703 | `-;
704 `-StaticAssertDeclaration
705   |-static_assert
706   |-(
707   |-UnknownExpression
708   | `-true
709   |-)
710   `-;
711        )txt"},
712       {R"cpp(
713 extern "C" int a;
714 extern "C" { int b; int c; }
715     )cpp",
716        R"txt(
717 *: TranslationUnit
718 |-LinkageSpecificationDeclaration
719 | |-extern
720 | |-"C"
721 | `-SimpleDeclaration
722 |   |-int
723 |   |-a
724 |   `-;
725 `-LinkageSpecificationDeclaration
726   |-extern
727   |-"C"
728   |-{
729   |-SimpleDeclaration
730   | |-int
731   | |-b
732   | `-;
733   |-SimpleDeclaration
734   | |-int
735   | |-c
736   | `-;
737   `-}
738        )txt"},
739       // Some nodes are non-modifiable, they are marked with 'I:'.
740       {R"cpp(
741 #define HALF_IF if (1+
742 #define HALF_IF_2 1) {}
743 void test() {
744   HALF_IF HALF_IF_2 else {}
745 })cpp",
746        R"txt(
747 *: TranslationUnit
748 `-SimpleDeclaration
749   |-void
750   |-test
751   |-(
752   |-)
753   `-CompoundStatement
754     |-{
755     |-IfStatement
756     | |-I: if
757     | |-I: (
758     | |-I: UnknownExpression
759     | | |-I: 1
760     | | |-I: +
761     | | `-I: 1
762     | |-I: )
763     | |-I: CompoundStatement
764     | | |-I: {
765     | | `-I: }
766     | |-else
767     | `-CompoundStatement
768     |   |-{
769     |   `-}
770     `-}
771        )txt"},
772       // All nodes can be mutated.
773       {R"cpp(
774 #define OPEN {
775 #define CLOSE }
776 
777 void test() {
778   OPEN
779     1;
780   CLOSE
781 
782   OPEN
783     2;
784   }
785 }
786 )cpp",
787        R"txt(
788 *: TranslationUnit
789 `-SimpleDeclaration
790   |-void
791   |-test
792   |-(
793   |-)
794   `-CompoundStatement
795     |-{
796     |-CompoundStatement
797     | |-{
798     | |-ExpressionStatement
799     | | |-UnknownExpression
800     | | | `-1
801     | | `-;
802     | `-}
803     |-CompoundStatement
804     | |-{
805     | |-ExpressionStatement
806     | | |-UnknownExpression
807     | | | `-2
808     | | `-;
809     | `-}
810     `-}
811        )txt"},
812   };
813 
814   for (const auto &T : Cases) {
815     SCOPED_TRACE(T.first);
816 
817     auto *Root = buildTree(T.first);
818     std::string Expected = llvm::StringRef(T.second).trim().str();
819     std::string Actual = llvm::StringRef(Root->dump(*Arena)).trim();
820     EXPECT_EQ(Expected, Actual) << "the resulting dump is:\n" << Actual;
821   }
822 }
823 
824 TEST_F(SyntaxTreeTest, Mutations) {
825   using Transformation = std::function<void(
826       const llvm::Annotations & /*Input*/, syntax::TranslationUnit * /*Root*/)>;
827   auto CheckTransformation = [this](std::string Input, std::string Expected,
828                                     Transformation Transform) -> void {
829     llvm::Annotations Source(Input);
830     auto *Root = buildTree(Source.code());
831 
832     Transform(Source, Root);
833 
834     auto Replacements = syntax::computeReplacements(*Arena, *Root);
835     auto Output = tooling::applyAllReplacements(Source.code(), Replacements);
836     if (!Output) {
837       ADD_FAILURE() << "could not apply replacements: "
838                     << llvm::toString(Output.takeError());
839       return;
840     }
841 
842     EXPECT_EQ(Expected, *Output) << "input is:\n" << Input;
843   };
844 
845   // Removes the selected statement. Input should have exactly one selected
846   // range and it should correspond to a single statement.
847   auto RemoveStatement = [this](const llvm::Annotations &Input,
848                                 syntax::TranslationUnit *TU) {
849     auto *S = cast<syntax::Statement>(nodeByRange(Input.range(), TU));
850     ASSERT_TRUE(S->canModify()) << "cannot remove a statement";
851     syntax::removeStatement(*Arena, S);
852   };
853 
854   std::vector<std::pair<std::string /*Input*/, std::string /*Expected*/>>
855       Cases = {
856           {"void test() { [[100+100;]] test(); }", "void test() {  test(); }"},
857           {"void test() { if (true) [[{}]] else {} }",
858            "void test() { if (true) ; else {} }"},
859           {"void test() { [[;]] }", "void test() {  }"}};
860   for (const auto &C : Cases)
861     CheckTransformation(C.first, C.second, RemoveStatement);
862 }
863 
864 } // namespace
865