1 //===------ LoopGenerators.cpp -  IR helper to create loops ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains functions to create scalar and OpenMP parallel loops
11 // as LLVM-IR.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "polly/ScopDetection.h"
16 #include "polly/CodeGen/LoopGenerators.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/IR/DataLayout.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 
23 using namespace llvm;
24 using namespace polly;
25 
26 // We generate a loop of the following structure
27 //
28 //              BeforeBB
29 //                 |
30 //                 v
31 //              GuardBB
32 //              /      |
33 //     __  PreHeaderBB  |
34 //    /  \    /         |
35 // latch  HeaderBB      |
36 //    \  /    \         /
37 //     <       \       /
38 //              \     /
39 //              ExitBB
40 //
41 // GuardBB checks if the loop is executed at least once. If this is the case
42 // we branch to PreHeaderBB and subsequently to the HeaderBB, which contains the
43 // loop iv 'polly.indvar', the incremented loop iv 'polly.indvar_next' as well
44 // as the condition to check if we execute another iteration of the loop. After
45 // the loop has finished, we branch to ExitBB.
46 //
47 // TODO: We currently always create the GuardBB. If we can prove the loop is
48 //       always executed at least once, we can get rid of this branch.
49 Value *polly::createLoop(Value *LB, Value *UB, Value *Stride,
50                          PollyIRBuilder &Builder, Pass *P, BasicBlock *&ExitBB,
51                          ICmpInst::Predicate Predicate,
52                          LoopAnnotator *Annotator, bool Parallel) {
53   DominatorTree &DT = P->getAnalysis<DominatorTreeWrapperPass>().getDomTree();
54   LoopInfo &LI = P->getAnalysis<LoopInfo>();
55   Function *F = Builder.GetInsertBlock()->getParent();
56   LLVMContext &Context = F->getContext();
57 
58   assert(LB->getType() == UB->getType() && "Types of loop bounds do not match");
59   IntegerType *LoopIVType = dyn_cast<IntegerType>(UB->getType());
60   assert(LoopIVType && "UB is not integer?");
61 
62   BasicBlock *BeforeBB = Builder.GetInsertBlock();
63   BasicBlock *GuardBB = BasicBlock::Create(Context, "polly.loop_if", F);
64   BasicBlock *HeaderBB = BasicBlock::Create(Context, "polly.loop_header", F);
65   BasicBlock *PreHeaderBB =
66       BasicBlock::Create(Context, "polly.loop_preheader", F);
67 
68   if (Annotator) {
69     Annotator->Begin(HeaderBB);
70     if (Parallel)
71       Annotator->SetCurrentParallel();
72   }
73 
74   // Update LoopInfo
75   Loop *OuterLoop = LI.getLoopFor(BeforeBB);
76   Loop *NewLoop = new Loop();
77 
78   if (OuterLoop) {
79     OuterLoop->addChildLoop(NewLoop);
80   } else {
81     LI.addTopLevelLoop(NewLoop);
82   }
83 
84   if (OuterLoop) {
85     OuterLoop->addBasicBlockToLoop(GuardBB, LI.getBase());
86     OuterLoop->addBasicBlockToLoop(PreHeaderBB, LI.getBase());
87   }
88 
89   NewLoop->addBasicBlockToLoop(HeaderBB, LI.getBase());
90 
91   // ExitBB
92   ExitBB = SplitBlock(BeforeBB, Builder.GetInsertPoint()++, P);
93   ExitBB->setName("polly.loop_exit");
94 
95   // BeforeBB
96   BeforeBB->getTerminator()->setSuccessor(0, GuardBB);
97 
98   // GuardBB
99   DT.addNewBlock(GuardBB, BeforeBB);
100   Builder.SetInsertPoint(GuardBB);
101   Value *LoopGuard;
102   LoopGuard = Builder.CreateICmp(Predicate, LB, UB);
103   LoopGuard->setName("polly.loop_guard");
104   Builder.CreateCondBr(LoopGuard, PreHeaderBB, ExitBB);
105 
106   // PreHeaderBB
107   DT.addNewBlock(PreHeaderBB, GuardBB);
108   Builder.SetInsertPoint(PreHeaderBB);
109   Builder.CreateBr(HeaderBB);
110 
111   // HeaderBB
112   DT.addNewBlock(HeaderBB, PreHeaderBB);
113   Builder.SetInsertPoint(HeaderBB);
114   PHINode *IV = Builder.CreatePHI(LoopIVType, 2, "polly.indvar");
115   IV->addIncoming(LB, PreHeaderBB);
116   Stride = Builder.CreateZExtOrBitCast(Stride, LoopIVType);
117   Value *IncrementedIV = Builder.CreateNSWAdd(IV, Stride, "polly.indvar_next");
118   Value *LoopCondition;
119   UB = Builder.CreateSub(UB, Stride, "polly.adjust_ub");
120   LoopCondition = Builder.CreateICmp(Predicate, IV, UB);
121   LoopCondition->setName("polly.loop_cond");
122   Builder.CreateCondBr(LoopCondition, HeaderBB, ExitBB);
123   IV->addIncoming(IncrementedIV, HeaderBB);
124   DT.changeImmediateDominator(ExitBB, GuardBB);
125 
126   // The loop body should be added here.
127   Builder.SetInsertPoint(HeaderBB->getFirstNonPHI());
128   return IV;
129 }
130 
131 void OMPGenerator::createCallParallelLoopStart(
132     Value *SubFunction, Value *SubfunctionParam, Value *NumberOfThreads,
133     Value *LowerBound, Value *UpperBound, Value *Stride) {
134   Module *M = getModule();
135   const char *Name = "GOMP_parallel_loop_runtime_start";
136   Function *F = M->getFunction(Name);
137 
138   // If F is not available, declare it.
139   if (!F) {
140     Type *LongTy = getIntPtrTy();
141     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
142 
143     Type *Params[] = {PointerType::getUnqual(FunctionType::get(
144                           Builder.getVoidTy(), Builder.getInt8PtrTy(), false)),
145                       Builder.getInt8PtrTy(), Builder.getInt32Ty(), LongTy,
146                       LongTy, LongTy};
147 
148     FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, false);
149     F = Function::Create(Ty, Linkage, Name, M);
150   }
151 
152   Value *Args[] = {SubFunction, SubfunctionParam, NumberOfThreads,
153                    LowerBound,  UpperBound,       Stride};
154 
155   Builder.CreateCall(F, Args);
156 }
157 
158 Value *OMPGenerator::createCallLoopNext(Value *LowerBoundPtr,
159                                         Value *UpperBoundPtr) {
160   Module *M = getModule();
161   const char *Name = "GOMP_loop_runtime_next";
162   Function *F = M->getFunction(Name);
163 
164   // If F is not available, declare it.
165   if (!F) {
166     Type *LongPtrTy = PointerType::getUnqual(getIntPtrTy());
167     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
168 
169     Type *Params[] = {LongPtrTy, LongPtrTy};
170 
171     FunctionType *Ty = FunctionType::get(Builder.getInt8Ty(), Params, false);
172     F = Function::Create(Ty, Linkage, Name, M);
173   }
174 
175   Value *Args[] = {LowerBoundPtr, UpperBoundPtr};
176 
177   Value *Return = Builder.CreateCall(F, Args);
178   Return = Builder.CreateICmpNE(
179       Return, Builder.CreateZExt(Builder.getFalse(), Return->getType()));
180   return Return;
181 }
182 
183 void OMPGenerator::createCallParallelEnd() {
184   const char *Name = "GOMP_parallel_end";
185   Module *M = getModule();
186   Function *F = M->getFunction(Name);
187 
188   // If F is not available, declare it.
189   if (!F) {
190     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
191 
192     FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
193     F = Function::Create(Ty, Linkage, Name, M);
194   }
195 
196   Builder.CreateCall(F);
197 }
198 
199 void OMPGenerator::createCallLoopEndNowait() {
200   const char *Name = "GOMP_loop_end_nowait";
201   Module *M = getModule();
202   Function *F = M->getFunction(Name);
203 
204   // If F is not available, declare it.
205   if (!F) {
206     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
207 
208     FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
209     F = Function::Create(Ty, Linkage, Name, M);
210   }
211 
212   Builder.CreateCall(F);
213 }
214 
215 IntegerType *OMPGenerator::getIntPtrTy() {
216   return P->getAnalysis<DataLayoutPass>().getDataLayout().getIntPtrType(
217       Builder.getContext());
218 }
219 
220 Module *OMPGenerator::getModule() {
221   return Builder.GetInsertBlock()->getParent()->getParent();
222 }
223 
224 Function *OMPGenerator::createSubfunctionDefinition() {
225   Module *M = getModule();
226   Function *F = Builder.GetInsertBlock()->getParent();
227   std::vector<Type *> Arguments(1, Builder.getInt8PtrTy());
228   FunctionType *FT = FunctionType::get(Builder.getVoidTy(), Arguments, false);
229   Function *FN = Function::Create(FT, Function::InternalLinkage,
230                                   F->getName() + ".omp_subfn", M);
231   // Do not run any polly pass on the new function.
232   P->getAnalysis<polly::ScopDetection>().markFunctionAsInvalid(FN);
233 
234   Function::arg_iterator AI = FN->arg_begin();
235   AI->setName("omp.userContext");
236 
237   return FN;
238 }
239 
240 Value *OMPGenerator::loadValuesIntoStruct(SetVector<Value *> &Values) {
241   std::vector<Type *> Members;
242 
243   for (unsigned i = 0; i < Values.size(); i++)
244     Members.push_back(Values[i]->getType());
245 
246   StructType *Ty = StructType::get(Builder.getContext(), Members);
247   Value *Struct = Builder.CreateAlloca(Ty, 0, "omp.userContext");
248 
249   for (unsigned i = 0; i < Values.size(); i++) {
250     Value *Address = Builder.CreateStructGEP(Struct, i);
251     Builder.CreateStore(Values[i], Address);
252   }
253 
254   return Struct;
255 }
256 
257 void OMPGenerator::extractValuesFromStruct(SetVector<Value *> OldValues,
258                                            Value *Struct,
259                                            ValueToValueMapTy &Map) {
260   for (unsigned i = 0; i < OldValues.size(); i++) {
261     Value *Address = Builder.CreateStructGEP(Struct, i);
262     Value *NewValue = Builder.CreateLoad(Address);
263     Map.insert(std::make_pair(OldValues[i], NewValue));
264   }
265 }
266 
267 Value *OMPGenerator::createSubfunction(Value *Stride, Value *StructData,
268                                        SetVector<Value *> Data,
269                                        ValueToValueMapTy &Map,
270                                        Function **SubFunction) {
271   Function *FN = createSubfunctionDefinition();
272 
273   BasicBlock *PrevBB, *HeaderBB, *ExitBB, *CheckNextBB, *LoadIVBoundsBB,
274       *AfterBB;
275   Value *LowerBoundPtr, *UpperBoundPtr, *UserContext, *Ret1, *HasNextSchedule,
276       *LowerBound, *UpperBound, *IV;
277   Type *IntPtrTy = getIntPtrTy();
278   LLVMContext &Context = FN->getContext();
279 
280   // Store the previous basic block.
281   PrevBB = Builder.GetInsertBlock();
282 
283   // Create basic blocks.
284   HeaderBB = BasicBlock::Create(Context, "omp.setup", FN);
285   ExitBB = BasicBlock::Create(Context, "omp.exit", FN);
286   CheckNextBB = BasicBlock::Create(Context, "omp.checkNext", FN);
287   LoadIVBoundsBB = BasicBlock::Create(Context, "omp.loadIVBounds", FN);
288 
289   DominatorTree &DT = P->getAnalysis<DominatorTreeWrapperPass>().getDomTree();
290   DT.addNewBlock(HeaderBB, PrevBB);
291   DT.addNewBlock(ExitBB, HeaderBB);
292   DT.addNewBlock(CheckNextBB, HeaderBB);
293   DT.addNewBlock(LoadIVBoundsBB, HeaderBB);
294 
295   // Fill up basic block HeaderBB.
296   Builder.SetInsertPoint(HeaderBB);
297   LowerBoundPtr = Builder.CreateAlloca(IntPtrTy, 0, "omp.lowerBoundPtr");
298   UpperBoundPtr = Builder.CreateAlloca(IntPtrTy, 0, "omp.upperBoundPtr");
299   UserContext = Builder.CreateBitCast(FN->arg_begin(), StructData->getType(),
300                                       "omp.userContext");
301 
302   extractValuesFromStruct(Data, UserContext, Map);
303   Builder.CreateBr(CheckNextBB);
304 
305   // Add code to check if another set of iterations will be executed.
306   Builder.SetInsertPoint(CheckNextBB);
307   Ret1 = createCallLoopNext(LowerBoundPtr, UpperBoundPtr);
308   HasNextSchedule = Builder.CreateTrunc(Ret1, Builder.getInt1Ty(),
309                                         "omp.hasNextScheduleBlock");
310   Builder.CreateCondBr(HasNextSchedule, LoadIVBoundsBB, ExitBB);
311 
312   // Add code to to load the iv bounds for this set of iterations.
313   Builder.SetInsertPoint(LoadIVBoundsBB);
314   LowerBound = Builder.CreateLoad(LowerBoundPtr, "omp.lowerBound");
315   UpperBound = Builder.CreateLoad(UpperBoundPtr, "omp.upperBound");
316 
317   // Subtract one as the upper bound provided by openmp is a < comparison
318   // whereas the codegenForSequential function creates a <= comparison.
319   UpperBound = Builder.CreateSub(UpperBound, ConstantInt::get(IntPtrTy, 1),
320                                  "omp.upperBoundAdjusted");
321 
322   Builder.CreateBr(CheckNextBB);
323   Builder.SetInsertPoint(--Builder.GetInsertPoint());
324   IV = createLoop(LowerBound, UpperBound, Stride, Builder, P, AfterBB,
325                   ICmpInst::ICMP_SLE);
326 
327   BasicBlock::iterator LoopBody = Builder.GetInsertPoint();
328   Builder.SetInsertPoint(AfterBB->begin());
329 
330   // Add code to terminate this openmp subfunction.
331   Builder.SetInsertPoint(ExitBB);
332   createCallLoopEndNowait();
333   Builder.CreateRetVoid();
334 
335   Builder.SetInsertPoint(LoopBody);
336   *SubFunction = FN;
337 
338   return IV;
339 }
340 
341 Value *OMPGenerator::createParallelLoop(Value *LowerBound, Value *UpperBound,
342                                         Value *Stride,
343                                         SetVector<Value *> &Values,
344                                         ValueToValueMapTy &Map,
345                                         BasicBlock::iterator *LoopBody) {
346   Value *Struct, *IV, *SubfunctionParam, *NumberOfThreads;
347   Function *SubFunction;
348 
349   Struct = loadValuesIntoStruct(Values);
350 
351   BasicBlock::iterator PrevInsertPoint = Builder.GetInsertPoint();
352   IV = createSubfunction(Stride, Struct, Values, Map, &SubFunction);
353   *LoopBody = Builder.GetInsertPoint();
354   Builder.SetInsertPoint(PrevInsertPoint);
355 
356   // Create call for GOMP_parallel_loop_runtime_start.
357   SubfunctionParam =
358       Builder.CreateBitCast(Struct, Builder.getInt8PtrTy(), "omp_data");
359 
360   NumberOfThreads = Builder.getInt32(0);
361 
362   // Add one as the upper bound provided by openmp is a < comparison
363   // whereas the codegenForSequential function creates a <= comparison.
364   UpperBound =
365       Builder.CreateAdd(UpperBound, ConstantInt::get(getIntPtrTy(), 1));
366 
367   createCallParallelLoopStart(SubFunction, SubfunctionParam, NumberOfThreads,
368                               LowerBound, UpperBound, Stride);
369   Builder.CreateCall(SubFunction, SubfunctionParam);
370   createCallParallelEnd();
371 
372   return IV;
373 }
374