1 //===- CGSCCPassManagerTest.cpp -------------------------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LazyCallGraph.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "gtest/gtest.h"
20 
21 using namespace llvm;
22 
23 namespace {
24 
25 class TestModuleAnalysis {
26 public:
27   struct Result {
28     Result(int Count) : FunctionCount(Count) {}
29     int FunctionCount;
30   };
31 
32   static void *ID() { return (void *)&PassID; }
33   static StringRef name() { return "TestModuleAnalysis"; }
34 
35   TestModuleAnalysis(int &Runs) : Runs(Runs) {}
36 
37   Result run(Module &M, ModuleAnalysisManager &AM) {
38     ++Runs;
39     return Result(M.size());
40   }
41 
42 private:
43   static char PassID;
44 
45   int &Runs;
46 };
47 
48 char TestModuleAnalysis::PassID;
49 
50 class TestSCCAnalysis {
51 public:
52   struct Result {
53     Result(int Count) : FunctionCount(Count) {}
54     int FunctionCount;
55   };
56 
57   static void *ID() { return (void *)&PassID; }
58   static StringRef name() { return "TestSCCAnalysis"; }
59 
60   TestSCCAnalysis(int &Runs) : Runs(Runs) {}
61 
62   Result run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM) {
63     ++Runs;
64     return Result(C.size());
65   }
66 
67 private:
68   static char PassID;
69 
70   int &Runs;
71 };
72 
73 char TestSCCAnalysis::PassID;
74 
75 class TestFunctionAnalysis {
76 public:
77   struct Result {
78     Result(int Count) : InstructionCount(Count) {}
79     int InstructionCount;
80   };
81 
82   static void *ID() { return (void *)&PassID; }
83   static StringRef name() { return "TestFunctionAnalysis"; }
84 
85   TestFunctionAnalysis(int &Runs) : Runs(Runs) {}
86 
87   Result run(Function &F, FunctionAnalysisManager &AM) {
88     ++Runs;
89     int Count = 0;
90     for (Instruction &I : instructions(F)) {
91       (void)I;
92       ++Count;
93     }
94     return Result(Count);
95   }
96 
97 private:
98   static char PassID;
99 
100   int &Runs;
101 };
102 
103 char TestFunctionAnalysis::PassID;
104 
105 class TestImmutableFunctionAnalysis {
106 public:
107   struct Result {
108     bool invalidate(Function &, const PreservedAnalyses &) { return false; }
109   };
110 
111   static void *ID() { return (void *)&PassID; }
112   static StringRef name() { return "TestImmutableFunctionAnalysis"; }
113 
114   TestImmutableFunctionAnalysis(int &Runs) : Runs(Runs) {}
115 
116   Result run(Function &F, FunctionAnalysisManager &AM) {
117     ++Runs;
118     return Result();
119   }
120 
121 private:
122   static char PassID;
123 
124   int &Runs;
125 };
126 
127 char TestImmutableFunctionAnalysis::PassID;
128 
129 struct TestModulePass {
130   TestModulePass(int &RunCount) : RunCount(RunCount) {}
131 
132   PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM) {
133     ++RunCount;
134     (void)AM.getResult<TestModuleAnalysis>(M);
135     return PreservedAnalyses::all();
136   }
137 
138   static StringRef name() { return "TestModulePass"; }
139 
140   int &RunCount;
141 };
142 
143 struct TestSCCPass {
144   TestSCCPass(int &RunCount, int &AnalyzedInstrCount,
145               int &AnalyzedSCCFunctionCount, int &AnalyzedModuleFunctionCount,
146               bool OnlyUseCachedResults = false)
147       : RunCount(RunCount), AnalyzedInstrCount(AnalyzedInstrCount),
148         AnalyzedSCCFunctionCount(AnalyzedSCCFunctionCount),
149         AnalyzedModuleFunctionCount(AnalyzedModuleFunctionCount),
150         OnlyUseCachedResults(OnlyUseCachedResults) {}
151 
152   PreservedAnalyses run(LazyCallGraph::SCC &C, CGSCCAnalysisManager &AM) {
153     ++RunCount;
154 
155     const ModuleAnalysisManager &MAM =
156         AM.getResult<ModuleAnalysisManagerCGSCCProxy>(C).getManager();
157     FunctionAnalysisManager &FAM =
158         AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C).getManager();
159     if (TestModuleAnalysis::Result *TMA =
160             MAM.getCachedResult<TestModuleAnalysis>(
161                 *C.begin()->getFunction().getParent()))
162       AnalyzedModuleFunctionCount += TMA->FunctionCount;
163 
164     if (OnlyUseCachedResults) {
165       // Hack to force the use of the cached interface.
166       if (TestSCCAnalysis::Result *AR = AM.getCachedResult<TestSCCAnalysis>(C))
167         AnalyzedSCCFunctionCount += AR->FunctionCount;
168       for (LazyCallGraph::Node &N : C)
169         if (TestFunctionAnalysis::Result *FAR =
170                 FAM.getCachedResult<TestFunctionAnalysis>(N.getFunction()))
171           AnalyzedInstrCount += FAR->InstructionCount;
172     } else {
173       // Typical path just runs the analysis as needed.
174       TestSCCAnalysis::Result &AR = AM.getResult<TestSCCAnalysis>(C);
175       AnalyzedSCCFunctionCount += AR.FunctionCount;
176       for (LazyCallGraph::Node &N : C) {
177         TestFunctionAnalysis::Result &FAR =
178             FAM.getResult<TestFunctionAnalysis>(N.getFunction());
179         AnalyzedInstrCount += FAR.InstructionCount;
180 
181         // Just ensure we get the immutable results.
182         (void)FAM.getResult<TestImmutableFunctionAnalysis>(N.getFunction());
183       }
184     }
185 
186     return PreservedAnalyses::all();
187   }
188 
189   static StringRef name() { return "TestSCCPass"; }
190 
191   int &RunCount;
192   int &AnalyzedInstrCount;
193   int &AnalyzedSCCFunctionCount;
194   int &AnalyzedModuleFunctionCount;
195   bool OnlyUseCachedResults;
196 };
197 
198 struct TestFunctionPass {
199   TestFunctionPass(int &RunCount) : RunCount(RunCount) {}
200 
201   PreservedAnalyses run(Function &F, AnalysisManager<Function> &) {
202     ++RunCount;
203     return PreservedAnalyses::none();
204   }
205 
206   static StringRef name() { return "TestFunctionPass"; }
207 
208   int &RunCount;
209 };
210 
211 std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
212   SMDiagnostic Err;
213   return parseAssemblyString(IR, Err, C);
214 }
215 
216 class CGSCCPassManagerTest : public ::testing::Test {
217 protected:
218   LLVMContext Context;
219   std::unique_ptr<Module> M;
220 
221 public:
222   CGSCCPassManagerTest()
223       : M(parseIR(Context, "define void @f() {\n"
224                            "entry:\n"
225                            "  call void @g()\n"
226                            "  call void @h1()\n"
227                            "  ret void\n"
228                            "}\n"
229                            "define void @g() {\n"
230                            "entry:\n"
231                            "  call void @g()\n"
232                            "  call void @x()\n"
233                            "  ret void\n"
234                            "}\n"
235                            "define void @h1() {\n"
236                            "entry:\n"
237                            "  call void @h2()\n"
238                            "  ret void\n"
239                            "}\n"
240                            "define void @h2() {\n"
241                            "entry:\n"
242                            "  call void @h3()\n"
243                            "  call void @x()\n"
244                            "  ret void\n"
245                            "}\n"
246                            "define void @h3() {\n"
247                            "entry:\n"
248                            "  call void @h1()\n"
249                            "  ret void\n"
250                            "}\n"
251                            "define void @x() {\n"
252                            "entry:\n"
253                            "  ret void\n"
254                            "}\n")) {}
255 };
256 
257 TEST_F(CGSCCPassManagerTest, Basic) {
258   FunctionAnalysisManager FAM(/*DebugLogging*/ true);
259   int FunctionAnalysisRuns = 0;
260   FAM.registerPass([&] { return TestFunctionAnalysis(FunctionAnalysisRuns); });
261   int ImmutableFunctionAnalysisRuns = 0;
262   FAM.registerPass([&] {
263     return TestImmutableFunctionAnalysis(ImmutableFunctionAnalysisRuns);
264   });
265 
266   CGSCCAnalysisManager CGAM(/*DebugLogging*/ true);
267   int SCCAnalysisRuns = 0;
268   CGAM.registerPass([&] { return TestSCCAnalysis(SCCAnalysisRuns); });
269 
270   ModuleAnalysisManager MAM(/*DebugLogging*/ true);
271   int ModuleAnalysisRuns = 0;
272   MAM.registerPass([&] { return LazyCallGraphAnalysis(); });
273   MAM.registerPass([&] { return TestModuleAnalysis(ModuleAnalysisRuns); });
274 
275   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
276   MAM.registerPass([&] { return CGSCCAnalysisManagerModuleProxy(CGAM); });
277   CGAM.registerPass([&] { return FunctionAnalysisManagerCGSCCProxy(FAM); });
278   CGAM.registerPass([&] { return ModuleAnalysisManagerCGSCCProxy(MAM); });
279   FAM.registerPass([&] { return CGSCCAnalysisManagerFunctionProxy(CGAM); });
280   FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
281 
282   ModulePassManager MPM(/*DebugLogging*/ true);
283   int ModulePassRunCount1 = 0;
284   MPM.addPass(TestModulePass(ModulePassRunCount1));
285 
286   CGSCCPassManager CGPM1(/*DebugLogging*/ true);
287   int SCCPassRunCount1 = 0;
288   int AnalyzedInstrCount1 = 0;
289   int AnalyzedSCCFunctionCount1 = 0;
290   int AnalyzedModuleFunctionCount1 = 0;
291   CGPM1.addPass(TestSCCPass(SCCPassRunCount1, AnalyzedInstrCount1,
292                             AnalyzedSCCFunctionCount1,
293                             AnalyzedModuleFunctionCount1));
294 
295   FunctionPassManager FPM1(/*DebugLogging*/ true);
296   int FunctionPassRunCount1 = 0;
297   FPM1.addPass(TestFunctionPass(FunctionPassRunCount1));
298   CGPM1.addPass(createCGSCCToFunctionPassAdaptor(std::move(FPM1)));
299   MPM.addPass(createModuleToPostOrderCGSCCPassAdaptor(std::move(CGPM1)));
300 
301   MPM.run(*M, MAM);
302 
303   EXPECT_EQ(1, ModulePassRunCount1);
304 
305   EXPECT_EQ(1, ModuleAnalysisRuns);
306   EXPECT_EQ(4, SCCAnalysisRuns);
307   EXPECT_EQ(6, FunctionAnalysisRuns);
308   EXPECT_EQ(6, ImmutableFunctionAnalysisRuns);
309 
310   EXPECT_EQ(4, SCCPassRunCount1);
311   EXPECT_EQ(14, AnalyzedInstrCount1);
312   EXPECT_EQ(6, AnalyzedSCCFunctionCount1);
313   EXPECT_EQ(4 * 6, AnalyzedModuleFunctionCount1);
314 }
315 
316 }
317