1 //===------ LoopGenerators.cpp -  IR helper to create loops ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains functions to create scalar and OpenMP parallel loops
11 // as LLVM-IR.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "polly/ScopDetection.h"
16 #include "polly/CodeGen/LoopGenerators.h"
17 #include "llvm/Analysis/LoopInfo.h"
18 #include "llvm/IR/DataLayout.h"
19 #include "llvm/IR/Dominators.h"
20 #include "llvm/IR/Module.h"
21 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
22 
23 using namespace llvm;
24 using namespace polly;
25 
26 // We generate a loop of the following structure
27 //
28 //              BeforeBB
29 //                 |
30 //                 v
31 //              GuardBB
32 //              /      |
33 //     __  PreHeaderBB  |
34 //    /  \    /         |
35 // latch  HeaderBB      |
36 //    \  /    \         /
37 //     <       \       /
38 //              \     /
39 //              ExitBB
40 //
41 // GuardBB checks if the loop is executed at least once. If this is the case
42 // we branch to PreHeaderBB and subsequently to the HeaderBB, which contains the
43 // loop iv 'polly.indvar', the incremented loop iv 'polly.indvar_next' as well
44 // as the condition to check if we execute another iteration of the loop. After
45 // the loop has finished, we branch to ExitBB.
46 //
47 // TODO: We currently always create the GuardBB. If we can prove the loop is
48 //       always executed at least once, we can get rid of this branch.
49 Value *polly::createLoop(Value *LB, Value *UB, Value *Stride,
50                          PollyIRBuilder &Builder, Pass *P, BasicBlock *&ExitBB,
51                          ICmpInst::Predicate Predicate,
52                          LoopAnnotator *Annotator, bool Parallel) {
53 
54   DominatorTree &DT = P->getAnalysis<DominatorTreeWrapperPass>().getDomTree();
55   LoopInfo &LI = P->getAnalysis<LoopInfo>();
56   Function *F = Builder.GetInsertBlock()->getParent();
57   LLVMContext &Context = F->getContext();
58 
59   assert(LB->getType() == UB->getType() && "Types of loop bounds do not match");
60   IntegerType *LoopIVType = dyn_cast<IntegerType>(UB->getType());
61   assert(LoopIVType && "UB is not integer?");
62 
63   BasicBlock *BeforeBB = Builder.GetInsertBlock();
64   BasicBlock *GuardBB = BasicBlock::Create(Context, "polly.loop_if", F);
65   BasicBlock *HeaderBB = BasicBlock::Create(Context, "polly.loop_header", F);
66   BasicBlock *PreHeaderBB =
67       BasicBlock::Create(Context, "polly.loop_preheader", F);
68 
69   if (Annotator) {
70     Annotator->Begin(HeaderBB);
71     if (Parallel)
72       Annotator->SetCurrentParallel();
73   }
74 
75   // Update LoopInfo
76   Loop *OuterLoop = LI.getLoopFor(BeforeBB);
77   Loop *NewLoop = new Loop();
78 
79   if (OuterLoop) {
80     OuterLoop->addChildLoop(NewLoop);
81   } else {
82     LI.addTopLevelLoop(NewLoop);
83   }
84 
85   if (OuterLoop) {
86     OuterLoop->addBasicBlockToLoop(GuardBB, LI.getBase());
87     OuterLoop->addBasicBlockToLoop(PreHeaderBB, LI.getBase());
88   }
89 
90   NewLoop->addBasicBlockToLoop(HeaderBB, LI.getBase());
91 
92   // ExitBB
93   ExitBB = SplitBlock(BeforeBB, Builder.GetInsertPoint()++, P);
94   ExitBB->setName("polly.loop_exit");
95 
96   // BeforeBB
97   BeforeBB->getTerminator()->setSuccessor(0, GuardBB);
98 
99   // GuardBB
100   DT.addNewBlock(GuardBB, BeforeBB);
101   Builder.SetInsertPoint(GuardBB);
102   Value *LoopGuard;
103   LoopGuard = Builder.CreateICmp(Predicate, LB, UB);
104   LoopGuard->setName("polly.loop_guard");
105   Builder.CreateCondBr(LoopGuard, PreHeaderBB, ExitBB);
106 
107   // PreHeaderBB
108   DT.addNewBlock(PreHeaderBB, GuardBB);
109   Builder.SetInsertPoint(PreHeaderBB);
110   Builder.CreateBr(HeaderBB);
111 
112   // HeaderBB
113   DT.addNewBlock(HeaderBB, PreHeaderBB);
114   Builder.SetInsertPoint(HeaderBB);
115   PHINode *IV = Builder.CreatePHI(LoopIVType, 2, "polly.indvar");
116   IV->addIncoming(LB, PreHeaderBB);
117   Stride = Builder.CreateZExtOrBitCast(Stride, LoopIVType);
118   Value *IncrementedIV = Builder.CreateNSWAdd(IV, Stride, "polly.indvar_next");
119   Value *LoopCondition;
120   UB = Builder.CreateSub(UB, Stride, "polly.adjust_ub");
121   LoopCondition = Builder.CreateICmp(Predicate, IV, UB);
122   LoopCondition->setName("polly.loop_cond");
123   Builder.CreateCondBr(LoopCondition, HeaderBB, ExitBB);
124   IV->addIncoming(IncrementedIV, HeaderBB);
125   DT.changeImmediateDominator(ExitBB, GuardBB);
126 
127   // The loop body should be added here.
128   Builder.SetInsertPoint(HeaderBB->getFirstNonPHI());
129   return IV;
130 }
131 
132 void OMPGenerator::createCallParallelLoopStart(
133     Value *SubFunction, Value *SubfunctionParam, Value *NumberOfThreads,
134     Value *LowerBound, Value *UpperBound, Value *Stride) {
135   Module *M = getModule();
136   const char *Name = "GOMP_parallel_loop_runtime_start";
137   Function *F = M->getFunction(Name);
138 
139   // If F is not available, declare it.
140   if (!F) {
141     Type *LongTy = getIntPtrTy();
142     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
143 
144     Type *Params[] = {PointerType::getUnqual(FunctionType::get(
145                           Builder.getVoidTy(), Builder.getInt8PtrTy(), false)),
146                       Builder.getInt8PtrTy(), Builder.getInt32Ty(), LongTy,
147                       LongTy, LongTy};
148 
149     FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), Params, false);
150     F = Function::Create(Ty, Linkage, Name, M);
151   }
152 
153   Value *Args[] = {SubFunction, SubfunctionParam, NumberOfThreads,
154                    LowerBound,  UpperBound,       Stride};
155 
156   Builder.CreateCall(F, Args);
157 }
158 
159 Value *OMPGenerator::createCallLoopNext(Value *LowerBoundPtr,
160                                         Value *UpperBoundPtr) {
161   Module *M = getModule();
162   const char *Name = "GOMP_loop_runtime_next";
163   Function *F = M->getFunction(Name);
164 
165   // If F is not available, declare it.
166   if (!F) {
167     Type *LongPtrTy = PointerType::getUnqual(getIntPtrTy());
168     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
169 
170     Type *Params[] = {LongPtrTy, LongPtrTy};
171 
172     FunctionType *Ty = FunctionType::get(Builder.getInt8Ty(), Params, false);
173     F = Function::Create(Ty, Linkage, Name, M);
174   }
175 
176   Value *Args[] = {LowerBoundPtr, UpperBoundPtr};
177 
178   Value *Return = Builder.CreateCall(F, Args);
179   Return = Builder.CreateICmpNE(
180       Return, Builder.CreateZExt(Builder.getFalse(), Return->getType()));
181   return Return;
182 }
183 
184 void OMPGenerator::createCallParallelEnd() {
185   const char *Name = "GOMP_parallel_end";
186   Module *M = getModule();
187   Function *F = M->getFunction(Name);
188 
189   // If F is not available, declare it.
190   if (!F) {
191     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
192 
193     FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
194     F = Function::Create(Ty, Linkage, Name, M);
195   }
196 
197   Builder.CreateCall(F);
198 }
199 
200 void OMPGenerator::createCallLoopEndNowait() {
201   const char *Name = "GOMP_loop_end_nowait";
202   Module *M = getModule();
203   Function *F = M->getFunction(Name);
204 
205   // If F is not available, declare it.
206   if (!F) {
207     GlobalValue::LinkageTypes Linkage = Function::ExternalLinkage;
208 
209     FunctionType *Ty = FunctionType::get(Builder.getVoidTy(), false);
210     F = Function::Create(Ty, Linkage, Name, M);
211   }
212 
213   Builder.CreateCall(F);
214 }
215 
216 IntegerType *OMPGenerator::getIntPtrTy() {
217   return P->getAnalysis<DataLayoutPass>().getDataLayout().getIntPtrType(
218       Builder.getContext());
219 }
220 
221 Module *OMPGenerator::getModule() {
222   return Builder.GetInsertBlock()->getParent()->getParent();
223 }
224 
225 Function *OMPGenerator::createSubfunctionDefinition() {
226   Module *M = getModule();
227   Function *F = Builder.GetInsertBlock()->getParent();
228   std::vector<Type *> Arguments(1, Builder.getInt8PtrTy());
229   FunctionType *FT = FunctionType::get(Builder.getVoidTy(), Arguments, false);
230   Function *FN = Function::Create(FT, Function::InternalLinkage,
231                                   F->getName() + ".omp_subfn", M);
232   // Do not run any polly pass on the new function.
233   P->getAnalysis<polly::ScopDetection>().markFunctionAsInvalid(FN);
234 
235   Function::arg_iterator AI = FN->arg_begin();
236   AI->setName("omp.userContext");
237 
238   return FN;
239 }
240 
241 Value *OMPGenerator::loadValuesIntoStruct(SetVector<Value *> &Values) {
242   std::vector<Type *> Members;
243 
244   for (unsigned i = 0; i < Values.size(); i++)
245     Members.push_back(Values[i]->getType());
246 
247   StructType *Ty = StructType::get(Builder.getContext(), Members);
248   Value *Struct = Builder.CreateAlloca(Ty, 0, "omp.userContext");
249 
250   for (unsigned i = 0; i < Values.size(); i++) {
251     Value *Address = Builder.CreateStructGEP(Struct, i);
252     Builder.CreateStore(Values[i], Address);
253   }
254 
255   return Struct;
256 }
257 
258 void OMPGenerator::extractValuesFromStruct(SetVector<Value *> OldValues,
259                                            Value *Struct,
260                                            ValueToValueMapTy &Map) {
261   for (unsigned i = 0; i < OldValues.size(); i++) {
262     Value *Address = Builder.CreateStructGEP(Struct, i);
263     Value *NewValue = Builder.CreateLoad(Address);
264     Map.insert(std::make_pair(OldValues[i], NewValue));
265   }
266 }
267 
268 Value *OMPGenerator::createSubfunction(Value *Stride, Value *StructData,
269                                        SetVector<Value *> Data,
270                                        ValueToValueMapTy &Map,
271                                        Function **SubFunction) {
272   Function *FN = createSubfunctionDefinition();
273 
274   BasicBlock *PrevBB, *HeaderBB, *ExitBB, *CheckNextBB, *LoadIVBoundsBB,
275       *AfterBB;
276   Value *LowerBoundPtr, *UpperBoundPtr, *UserContext, *Ret1, *HasNextSchedule,
277       *LowerBound, *UpperBound, *IV;
278   Type *IntPtrTy = getIntPtrTy();
279   LLVMContext &Context = FN->getContext();
280 
281   // Store the previous basic block.
282   PrevBB = Builder.GetInsertBlock();
283 
284   // Create basic blocks.
285   HeaderBB = BasicBlock::Create(Context, "omp.setup", FN);
286   ExitBB = BasicBlock::Create(Context, "omp.exit", FN);
287   CheckNextBB = BasicBlock::Create(Context, "omp.checkNext", FN);
288   LoadIVBoundsBB = BasicBlock::Create(Context, "omp.loadIVBounds", FN);
289 
290   DominatorTree &DT = P->getAnalysis<DominatorTreeWrapperPass>().getDomTree();
291   DT.addNewBlock(HeaderBB, PrevBB);
292   DT.addNewBlock(ExitBB, HeaderBB);
293   DT.addNewBlock(CheckNextBB, HeaderBB);
294   DT.addNewBlock(LoadIVBoundsBB, HeaderBB);
295 
296   // Fill up basic block HeaderBB.
297   Builder.SetInsertPoint(HeaderBB);
298   LowerBoundPtr = Builder.CreateAlloca(IntPtrTy, 0, "omp.lowerBoundPtr");
299   UpperBoundPtr = Builder.CreateAlloca(IntPtrTy, 0, "omp.upperBoundPtr");
300   UserContext = Builder.CreateBitCast(FN->arg_begin(), StructData->getType(),
301                                       "omp.userContext");
302 
303   extractValuesFromStruct(Data, UserContext, Map);
304   Builder.CreateBr(CheckNextBB);
305 
306   // Add code to check if another set of iterations will be executed.
307   Builder.SetInsertPoint(CheckNextBB);
308   Ret1 = createCallLoopNext(LowerBoundPtr, UpperBoundPtr);
309   HasNextSchedule = Builder.CreateTrunc(Ret1, Builder.getInt1Ty(),
310                                         "omp.hasNextScheduleBlock");
311   Builder.CreateCondBr(HasNextSchedule, LoadIVBoundsBB, ExitBB);
312 
313   // Add code to to load the iv bounds for this set of iterations.
314   Builder.SetInsertPoint(LoadIVBoundsBB);
315   LowerBound = Builder.CreateLoad(LowerBoundPtr, "omp.lowerBound");
316   UpperBound = Builder.CreateLoad(UpperBoundPtr, "omp.upperBound");
317 
318   // Subtract one as the upper bound provided by openmp is a < comparison
319   // whereas the codegenForSequential function creates a <= comparison.
320   UpperBound = Builder.CreateSub(UpperBound, ConstantInt::get(IntPtrTy, 1),
321                                  "omp.upperBoundAdjusted");
322 
323   Builder.CreateBr(CheckNextBB);
324   Builder.SetInsertPoint(--Builder.GetInsertPoint());
325   IV = createLoop(LowerBound, UpperBound, Stride, Builder, P, AfterBB,
326                   ICmpInst::ICMP_SLE);
327 
328   BasicBlock::iterator LoopBody = Builder.GetInsertPoint();
329   Builder.SetInsertPoint(AfterBB->begin());
330 
331   // Add code to terminate this openmp subfunction.
332   Builder.SetInsertPoint(ExitBB);
333   createCallLoopEndNowait();
334   Builder.CreateRetVoid();
335 
336   Builder.SetInsertPoint(LoopBody);
337   *SubFunction = FN;
338 
339   return IV;
340 }
341 
342 Value *OMPGenerator::createParallelLoop(Value *LowerBound, Value *UpperBound,
343                                         Value *Stride,
344                                         SetVector<Value *> &Values,
345                                         ValueToValueMapTy &Map,
346                                         BasicBlock::iterator *LoopBody) {
347   Value *Struct, *IV, *SubfunctionParam, *NumberOfThreads;
348   Function *SubFunction;
349 
350   Struct = loadValuesIntoStruct(Values);
351 
352   BasicBlock::iterator PrevInsertPoint = Builder.GetInsertPoint();
353   IV = createSubfunction(Stride, Struct, Values, Map, &SubFunction);
354   *LoopBody = Builder.GetInsertPoint();
355   Builder.SetInsertPoint(PrevInsertPoint);
356 
357   // Create call for GOMP_parallel_loop_runtime_start.
358   SubfunctionParam =
359       Builder.CreateBitCast(Struct, Builder.getInt8PtrTy(), "omp_data");
360 
361   NumberOfThreads = Builder.getInt32(0);
362 
363   // Add one as the upper bound provided by openmp is a < comparison
364   // whereas the codegenForSequential function creates a <= comparison.
365   UpperBound =
366       Builder.CreateAdd(UpperBound, ConstantInt::get(getIntPtrTy(), 1));
367 
368   createCallParallelLoopStart(SubFunction, SubfunctionParam, NumberOfThreads,
369                               LowerBound, UpperBound, Stride);
370   Builder.CreateCall(SubFunction, SubfunctionParam);
371   createCallParallelEnd();
372 
373   return IV;
374 }
375