1 //===- CGSCCPassManagerTest.cpp -------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
20 
21 using namespace llvm;
22 
23 namespace {
24 
25 class TestModuleAnalysis {
26 public:
27   struct Result {
28     Result(int Count) : FunctionCount(Count) {}
29     int FunctionCount;
30   };
31 
32   static void *ID() { return (void *)&PassID; }
33   static StringRef name() { return "TestModuleAnalysis"; }
34 
35   TestModuleAnalysis(int &Runs) : Runs(Runs) {}
36 
37   Result run(Module &M, ModuleAnalysisManager &AM) {
38     ++Runs;
39     return Result(M.size());
40   }
41 
42 private:
43   static char PassID;
44 
45   int &Runs;
46 };
47 
48 char TestModuleAnalysis::PassID;
49 
50 class TestSCCAnalysis {
51 public:
52   struct Result {
53     Result(int Count) : FunctionCount(Count) {}
54     int FunctionCount;
55   };
56 
57   static void *ID() { return (void *)&PassID; }
58   static StringRef name() { return "TestSCCAnalysis"; }
59 
60   TestSCCAnalysis(int &Runs) : Runs(Runs) {}
61 
62   Result run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM) {
63     ++Runs;
64     return Result(C.size());
65   }
66 
67 private:
68   static char PassID;
69 
70   int &Runs;
71 };
72 
73 char TestSCCAnalysis::PassID;
74 
75 class TestFunctionAnalysis {
76 public:
77   struct Result {
78     Result(int Count) : InstructionCount(Count) {}
79     int InstructionCount;
80   };
81 
82   static void *ID() { return (void *)&PassID; }
83   static StringRef name() { return "TestFunctionAnalysis"; }
84 
85   TestFunctionAnalysis(int &Runs) : Runs(Runs) {}
86 
87   Result run(Function &F, FunctionAnalysisManager &AM) {
88     ++Runs;
89     int Count = 0;
90     for (Instruction &I : instructions(F)) {
91       (void)I;
92       ++Count;
93     }
94     return Result(Count);
95   }
96 
97 private:
98   static char PassID;
99 
100   int &Runs;
101 };
102 
103 char TestFunctionAnalysis::PassID;
104 
105 class TestImmutableFunctionAnalysis {
106 public:
107   struct Result {
108     bool invalidate(Function &, const PreservedAnalyses &) { return false; }
109   };
110 
111   static void *ID() { return (void *)&PassID; }
112   static StringRef name() { return "TestImmutableFunctionAnalysis"; }
113 
114   TestImmutableFunctionAnalysis(int &Runs) : Runs(Runs) {}
115 
116   Result run(Function &F, FunctionAnalysisManager &AM) {
117     ++Runs;
118     return Result();
119   }
120 
121 private:
122   static char PassID;
123 
124   int &Runs;
125 };
126 
127 char TestImmutableFunctionAnalysis::PassID;
128 
129 struct TestModulePass {
130   TestModulePass(int &RunCount) : RunCount(RunCount) {}
131 
132   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
133     ++RunCount;
134     (void)AM.getResult<TestModuleAnalysis>(M);
135     return PreservedAnalyses::all();
136   }
137 
138   static StringRef name() { return "TestModulePass"; }
139 
140   int &RunCount;
141 };
142 
143 struct TestSCCPass {
144   TestSCCPass(int &RunCount, int &AnalyzedInstrCount,
145                    int &AnalyzedSCCFunctionCount,
146                    int &AnalyzedModuleFunctionCount,
147                    bool OnlyUseCachedResults = false)
148       : RunCount(RunCount), AnalyzedInstrCount(AnalyzedInstrCount),
149         AnalyzedSCCFunctionCount(AnalyzedSCCFunctionCount),
150         AnalyzedModuleFunctionCount(AnalyzedModuleFunctionCount),
151         OnlyUseCachedResults(OnlyUseCachedResults) {}
152 
153   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM) {
154     ++RunCount;
155 
156     const ModuleAnalysisManager &MAM =
157         AM.getResult<ModuleAnalysisManagerCGSCCProxy>(C).getManager();
158     FunctionAnalysisManager &FAM =
159         AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C).getManager();
160     if (TestModuleAnalysis::Result *TMA =
161             MAM.getCachedResult<TestModuleAnalysis>(
162                 *C.begin()->getFunction().getParent()))
163       AnalyzedModuleFunctionCount += TMA->FunctionCount;
164 
165     if (OnlyUseCachedResults) {
166       // Hack to force the use of the cached interface.
167       if (TestSCCAnalysis::Result *AR = AM.getCachedResult<TestSCCAnalysis>(C))
168         AnalyzedSCCFunctionCount += AR->FunctionCount;
169       for (LazyCallGraph::Node &N : C)
170         if (TestFunctionAnalysis::Result *FAR =
171                 FAM.getCachedResult<TestFunctionAnalysis>(N.getFunction()))
172           AnalyzedInstrCount += FAR->InstructionCount;
173     } else {
174       // Typical path just runs the analysis as needed.
175       TestSCCAnalysis::Result &AR = AM.getResult<TestSCCAnalysis>(C);
176       AnalyzedSCCFunctionCount += AR.FunctionCount;
177       for (LazyCallGraph::Node &N : C) {
178         TestFunctionAnalysis::Result &FAR =
179             FAM.getResult<TestFunctionAnalysis>(N.getFunction());
180         AnalyzedInstrCount += FAR.InstructionCount;
181 
182         // Just ensure we get the immutable results.
183         (void)FAM.getResult<TestImmutableFunctionAnalysis>(N.getFunction());
184       }
185     }
186 
187     return PreservedAnalyses::all();
188   }
189 
190   static StringRef name() { return "TestSCCPass"; }
191 
192   int &RunCount;
193   int &AnalyzedInstrCount;
194   int &AnalyzedSCCFunctionCount;
195   int &AnalyzedModuleFunctionCount;
196   bool OnlyUseCachedResults;
197 };
198 
199 struct TestFunctionPass {
200   TestFunctionPass(int &RunCount) : RunCount(RunCount) {}
201 
202   PreservedAnalyses run(Function &M) {
203     ++RunCount;
204     return PreservedAnalyses::none();
205   }
206 
207   static StringRef name() { return "TestFunctionPass"; }
208 
209   int &RunCount;
210 };
211 
212 std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
213   SMDiagnostic Err;
214   return parseAssemblyString(IR, Err, C);
215 }
216 
217 class CGSCCPassManagerTest : public ::testing::Test {
218 protected:
219   LLVMContext Context;
220   std::unique_ptr<Module> M;
221 
222 public:
223   CGSCCPassManagerTest()
224       : M(parseIR(Context, "define void @f() {\n"
225                            "entry:\n"
226                            "  call void @g()\n"
227                            "  call void @h1()\n"
228                            "  ret void\n"
229                            "}\n"
230                            "define void @g() {\n"
231                            "entry:\n"
232                            "  call void @g()\n"
233                            "  call void @x()\n"
234                            "  ret void\n"
235                            "}\n"
236                            "define void @h1() {\n"
237                            "entry:\n"
238                            "  call void @h2()\n"
239                            "  ret void\n"
240                            "}\n"
241                            "define void @h2() {\n"
242                            "entry:\n"
243                            "  call void @h3()\n"
244                            "  call void @x()\n"
245                            "  ret void\n"
246                            "}\n"
247                            "define void @h3() {\n"
248                            "entry:\n"
249                            "  call void @h1()\n"
250                            "  ret void\n"
251                            "}\n"
252                            "define void @x() {\n"
253                            "entry:\n"
254                            "  ret void\n"
255                            "}\n")) {}
256 };
257 
258 TEST_F(CGSCCPassManagerTest, Basic) {
259   FunctionAnalysisManager FAM(/*DebugLogging*/ true);
260   int FunctionAnalysisRuns = 0;
261   FAM.registerPass([&] { return TestFunctionAnalysis(FunctionAnalysisRuns); });
262   int ImmutableFunctionAnalysisRuns = 0;
263   FAM.registerPass([&] {
264     return TestImmutableFunctionAnalysis(ImmutableFunctionAnalysisRuns);
265   });
266 
267   CGSCCAnalysisManager CGAM(/*DebugLogging*/ true);
268   int SCCAnalysisRuns = 0;
269   CGAM.registerPass([&] { return TestSCCAnalysis(SCCAnalysisRuns); });
270 
271   ModuleAnalysisManager MAM(/*DebugLogging*/ true);
272   int ModuleAnalysisRuns = 0;
273   MAM.registerPass([&] { return LazyCallGraphAnalysis(); });
274   MAM.registerPass([&] { return TestModuleAnalysis(ModuleAnalysisRuns); });
275 
276   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
277   MAM.registerPass([&] { return CGSCCAnalysisManagerModuleProxy(CGAM); });
278   CGAM.registerPass([&] { return FunctionAnalysisManagerCGSCCProxy(FAM); });
279   CGAM.registerPass([&] { return ModuleAnalysisManagerCGSCCProxy(MAM); });
280   FAM.registerPass([&] { return CGSCCAnalysisManagerFunctionProxy(CGAM); });
281   FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
282 
283   ModulePassManager MPM(/*DebugLogging*/ true);
284   int ModulePassRunCount1 = 0;
285   MPM.addPass(TestModulePass(ModulePassRunCount1));
286 
287   CGSCCPassManager CGPM1(/*DebugLogging*/ true);
288   int SCCPassRunCount1 = 0;
289   int AnalyzedInstrCount1 = 0;
290   int AnalyzedSCCFunctionCount1 = 0;
291   int AnalyzedModuleFunctionCount1 = 0;
292   CGPM1.addPass(TestSCCPass(SCCPassRunCount1, AnalyzedInstrCount1,
293                             AnalyzedSCCFunctionCount1,
294                             AnalyzedModuleFunctionCount1));
295 
296   FunctionPassManager FPM1(/*DebugLogging*/ true);
297   int FunctionPassRunCount1 = 0;
298   FPM1.addPass(TestFunctionPass(FunctionPassRunCount1));
299   CGPM1.addPass(createCGSCCToFunctionPassAdaptor(std::move(FPM1)));
300   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM1)));
301 
302   MPM.run(*M, MAM);
303 
304   EXPECT_EQ(1, ModulePassRunCount1);
305 
306   EXPECT_EQ(1, ModuleAnalysisRuns);
307   EXPECT_EQ(4, SCCAnalysisRuns);
308   EXPECT_EQ(6, FunctionAnalysisRuns);
309   EXPECT_EQ(6, ImmutableFunctionAnalysisRuns);
310 
311   EXPECT_EQ(4, SCCPassRunCount1);
312   EXPECT_EQ(14, AnalyzedInstrCount1);
313   EXPECT_EQ(6, AnalyzedSCCFunctionCount1);
314   EXPECT_EQ(4 * 6, AnalyzedModuleFunctionCount1);
315 }
316 
317 }
318