1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/Verifier.h"
18 #include "llvm/IRReader/IRReader.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21 
22 using namespace llvm;
23 
24 namespace {
25 BasicBlock *getBlockByName(Function *F, StringRef name) {
26   for (auto &BB : *F)
27     if (BB.getName() == name)
28       return &BB;
29   return nullptr;
30 }
31 
32 TEST(CodeExtractor, ExitStub) {
33   LLVMContext Ctx;
34   SMDiagnostic Err;
35   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
36     define i32 @foo(i32 %x, i32 %y, i32 %z) {
37     header:
38       %0 = icmp ugt i32 %x, %y
39       br i1 %0, label %body1, label %body2
40 
41     body1:
42       %1 = add i32 %z, 2
43       br label %notExtracted
44 
45     body2:
46       %2 = mul i32 %z, 7
47       br label %notExtracted
48 
49     notExtracted:
50       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
51       %4 = add i32 %3, %x
52       ret i32 %4
53     }
54   )invalid",
55                                                 Err, Ctx));
56 
57   Function *Func = M->getFunction("foo");
58   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
59                                            getBlockByName(Func, "body1"),
60                                            getBlockByName(Func, "body2") };
61 
62   CodeExtractor CE(Candidates);
63   EXPECT_TRUE(CE.isEligible());
64 
65   CodeExtractorAnalysisCache CEAC(*Func);
66   Function *Outlined = CE.extractCodeRegion(CEAC);
67   EXPECT_TRUE(Outlined);
68   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
69   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
70   // Ensure that PHI in exit block has only one incoming value (from code
71   // replacer block).
72   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
73   // Ensure that there is a PHI in outlined function with 2 incoming values.
74   EXPECT_TRUE(ExitSplit &&
75               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
76   EXPECT_FALSE(verifyFunction(*Outlined));
77   EXPECT_FALSE(verifyFunction(*Func));
78 }
79 
80 TEST(CodeExtractor, InputOutputMonitoring) {
81   LLVMContext Ctx;
82   SMDiagnostic Err;
83   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
84     define i32 @foo(i32 %x, i32 %y, i32 %z) {
85     header:
86       %0 = icmp ugt i32 %x, %y
87       br i1 %0, label %body1, label %body2
88 
89     body1:
90       %1 = add i32 %z, 2
91       br label %notExtracted
92 
93     body2:
94       %2 = mul i32 %z, 7
95       br label %notExtracted
96 
97     notExtracted:
98       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
99       %4 = add i32 %3, %x
100       ret i32 %4
101     }
102   )invalid",
103                                                 Err, Ctx));
104 
105   Function *Func = M->getFunction("foo");
106   SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
107                                           getBlockByName(Func, "body1"),
108                                           getBlockByName(Func, "body2")};
109 
110   CodeExtractor CE(Candidates);
111   EXPECT_TRUE(CE.isEligible());
112 
113   CodeExtractorAnalysisCache CEAC(*Func);
114   SetVector<Value *> Inputs, Outputs;
115   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
116   EXPECT_TRUE(Outlined);
117 
118   EXPECT_EQ(Inputs.size(), 3u);
119   EXPECT_EQ(Inputs[0], Func->getArg(2));
120   EXPECT_EQ(Inputs[1], Func->getArg(0));
121   EXPECT_EQ(Inputs[2], Func->getArg(1));
122   EXPECT_EQ(Outputs.size(), 1u);
123   StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
124   Value *OutputVal = SI->getValueOperand();
125   EXPECT_EQ(Outputs[0], OutputVal);
126   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
127   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
128   // Ensure that PHI in exit block has only one incoming value (from code
129   // replacer block).
130   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
131   // Ensure that there is a PHI in outlined function with 2 incoming values.
132   EXPECT_TRUE(ExitSplit &&
133               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
134   EXPECT_FALSE(verifyFunction(*Outlined));
135   EXPECT_FALSE(verifyFunction(*Func));
136 }
137 
138 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
139   LLVMContext Ctx;
140   SMDiagnostic Err;
141   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
142     define i32 @foo() {
143     header:
144       br i1 undef, label %extracted1, label %pred
145 
146     pred:
147       br i1 undef, label %exit1, label %exit2
148 
149     extracted1:
150       br i1 undef, label %extracted2, label %exit1
151 
152     extracted2:
153       br label %exit2
154 
155     exit1:
156       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
157       ret i32 %0
158 
159     exit2:
160       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
161       ret i32 %1
162     }
163   )invalid", Err, Ctx));
164 
165   Function *Func = M->getFunction("foo");
166   SmallVector<BasicBlock *, 2> ExtractedBlocks{
167     getBlockByName(Func, "extracted1"),
168     getBlockByName(Func, "extracted2")
169   };
170 
171   CodeExtractor CE(ExtractedBlocks);
172   EXPECT_TRUE(CE.isEligible());
173 
174   CodeExtractorAnalysisCache CEAC(*Func);
175   Function *Outlined = CE.extractCodeRegion(CEAC);
176   EXPECT_TRUE(Outlined);
177   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
178   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
179   // Ensure that PHIs in exits are not splitted (since that they have only one
180   // incoming value from extracted region).
181   EXPECT_TRUE(Exit1 &&
182           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
183   EXPECT_TRUE(Exit2 &&
184           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
185   EXPECT_FALSE(verifyFunction(*Outlined));
186   EXPECT_FALSE(verifyFunction(*Func));
187 }
188 
189 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
190   LLVMContext Ctx;
191   SMDiagnostic Err;
192   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
193     declare i8 @hoge()
194 
195     define i32 @foo() personality i8* null {
196       entry:
197         %call = invoke i8 @hoge()
198                 to label %invoke.cont unwind label %lpad
199 
200       invoke.cont:                                      ; preds = %entry
201         unreachable
202 
203       lpad:                                             ; preds = %entry
204         %0 = landingpad { i8*, i32 }
205                 catch i8* null
206         br i1 undef, label %catch, label %finally.catchall
207 
208       catch:                                            ; preds = %lpad
209         %call2 = invoke i8 @hoge()
210                 to label %invoke.cont2 unwind label %lpad2
211 
212       invoke.cont2:                                    ; preds = %catch
213         %call3 = invoke i8 @hoge()
214                 to label %invoke.cont3 unwind label %lpad2
215 
216       invoke.cont3:                                    ; preds = %invoke.cont2
217         unreachable
218 
219       lpad2:                                           ; preds = %invoke.cont2, %catch
220         %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
221         %1 = landingpad { i8*, i32 }
222                 catch i8* null
223         br label %finally.catchall
224 
225       finally.catchall:                                 ; preds = %lpad33, %lpad
226         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
227         unreachable
228     }
229   )invalid", Err, Ctx));
230 
231 	if (!M) {
232     Err.print("unit", errs());
233     exit(1);
234   }
235 
236   Function *Func = M->getFunction("foo");
237   EXPECT_FALSE(verifyFunction(*Func, &errs()));
238 
239   SmallVector<BasicBlock *, 2> ExtractedBlocks{
240     getBlockByName(Func, "catch"),
241     getBlockByName(Func, "invoke.cont2"),
242     getBlockByName(Func, "invoke.cont3"),
243     getBlockByName(Func, "lpad2")
244   };
245 
246   CodeExtractor CE(ExtractedBlocks);
247   EXPECT_TRUE(CE.isEligible());
248 
249   CodeExtractorAnalysisCache CEAC(*Func);
250   Function *Outlined = CE.extractCodeRegion(CEAC);
251   EXPECT_TRUE(Outlined);
252   EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
253   EXPECT_FALSE(verifyFunction(*Func, &errs()));
254 }
255 
256 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
257   LLVMContext Ctx;
258   SMDiagnostic Err;
259   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
260     declare i32 @bar()
261 
262     define i32 @foo() personality i8* null {
263     entry:
264       %0 = invoke i32 @bar() to label %exit unwind label %lpad
265 
266     exit:
267       ret i32 %0
268 
269     lpad:
270       %1 = landingpad { i8*, i32 }
271               cleanup
272       resume { i8*, i32 } %1
273     }
274   )invalid",
275                                                 Err, Ctx));
276 
277   Function *Func = M->getFunction("foo");
278   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
279                                        getBlockByName(Func, "lpad") };
280 
281   CodeExtractor CE(Blocks);
282   EXPECT_TRUE(CE.isEligible());
283 
284   CodeExtractorAnalysisCache CEAC(*Func);
285   Function *Outlined = CE.extractCodeRegion(CEAC);
286   EXPECT_TRUE(Outlined);
287   EXPECT_FALSE(verifyFunction(*Outlined));
288   EXPECT_FALSE(verifyFunction(*Func));
289 }
290 
291 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
292   LLVMContext Ctx;
293   SMDiagnostic Err;
294   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
295         target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
296         target triple = "aarch64"
297 
298         %b = type { i64 }
299         declare void @g(i8*)
300 
301         declare void @llvm.assume(i1) #0
302 
303         define void @test() {
304         entry:
305           br label %label
306 
307         label:
308           %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
309           %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
310           %2 = load i64, i64* %1, align 8
311           %3 = icmp ugt i64 %2, 1
312           br i1 %3, label %if.then, label %if.else
313 
314         if.then:
315           unreachable
316 
317         if.else:
318           call void @g(i8* undef)
319           store i64 undef, i64* null, align 536870912
320           %4 = icmp eq i64 %2, 0
321           call void @llvm.assume(i1 %4)
322           unreachable
323         }
324 
325         attributes #0 = { nounwind willreturn }
326   )ir",
327                                                 Err, Ctx));
328 
329   assert(M && "Could not parse module?");
330   Function *Func = M->getFunction("test");
331   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
332   AssumptionCache AC(*Func);
333   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
334   EXPECT_TRUE(CE.isEligible());
335 
336   CodeExtractorAnalysisCache CEAC(*Func);
337   Function *Outlined = CE.extractCodeRegion(CEAC);
338   EXPECT_TRUE(Outlined);
339   EXPECT_FALSE(verifyFunction(*Outlined));
340   EXPECT_FALSE(verifyFunction(*Func));
341   EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
342 }
343 
344 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
345   LLVMContext Ctx;
346   SMDiagnostic Err;
347   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
348     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
349     target triple = "x86_64-unknown-linux-gnu"
350 
351     declare void @use(i32*)
352     declare void @llvm.lifetime.start.p0i8(i64, i8*)
353     declare void @llvm.lifetime.end.p0i8(i64, i8*)
354 
355     define void @foo() {
356     entry:
357       %0 = alloca i32
358       br label %extract
359 
360     extract:
361       %1 = bitcast i32* %0 to i8*
362       call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
363       call void @use(i32* %0)
364       br label %exit
365 
366     exit:
367       call void @use(i32* %0)
368       call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
369       ret void
370     }
371   )ir",
372                                                 Err, Ctx));
373 
374   Function *Func = M->getFunction("foo");
375   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
376 
377   CodeExtractor CE(Blocks);
378   EXPECT_TRUE(CE.isEligible());
379 
380   CodeExtractorAnalysisCache CEAC(*Func);
381   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
382   BasicBlock *CommonExit = nullptr;
383   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
384   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
385   EXPECT_EQ(Outputs.size(), 0U);
386 
387   Function *Outlined = CE.extractCodeRegion(CEAC);
388   EXPECT_TRUE(Outlined);
389   EXPECT_FALSE(verifyFunction(*Outlined));
390   EXPECT_FALSE(verifyFunction(*Func));
391 }
392 } // end anonymous namespace
393