1 //===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 /// \file
9 ///
10 /// This file implements the OpenMPIRBuilder class, which is used as a
11 /// convenient way to create LLVM instructions for OpenMP directives.
12 ///
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
16 #include "llvm/ADT/SmallSet.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Analysis/AssumptionCache.h"
19 #include "llvm/Analysis/CodeMetrics.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/ScalarEvolution.h"
23 #include "llvm/Analysis/TargetLibraryInfo.h"
24 #include "llvm/IR/CFG.h"
25 #include "llvm/IR/Constants.h"
26 #include "llvm/IR/DebugInfoMetadata.h"
27 #include "llvm/IR/GlobalVariable.h"
28 #include "llvm/IR/IRBuilder.h"
29 #include "llvm/IR/MDBuilder.h"
30 #include "llvm/IR/PassManager.h"
31 #include "llvm/IR/Value.h"
32 #include "llvm/MC/TargetRegistry.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Target/TargetMachine.h"
35 #include "llvm/Target/TargetOptions.h"
36 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
37 #include "llvm/Transforms/Utils/CodeExtractor.h"
38 #include "llvm/Transforms/Utils/LoopPeel.h"
39 #include "llvm/Transforms/Utils/UnrollLoop.h"
40 
41 #include <cstdint>
42 
43 #define DEBUG_TYPE "openmp-ir-builder"
44 
45 using namespace llvm;
46 using namespace omp;
47 
// Command-line knob: when set, runtime-call declarations are annotated with
// attributes describing their 'as-if' behavior rather than only what is
// strictly guaranteed. Off by default.
static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

// Command-line knob: scaling factor applied to the unroll threshold,
// compensating for IR simplifications that later passes still perform.
// Presumably consumed by the unrolling codegen elsewhere in this file —
// confirm against the unroll implementation.
static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));
59 
60 #ifndef NDEBUG
61 /// Return whether IP1 and IP2 are ambiguous, i.e. that inserting instructions
62 /// at position IP1 may change the meaning of IP2 or vice-versa. This is because
63 /// an InsertPoint stores the instruction before something is inserted. For
64 /// instance, if both point to the same instruction, two IRBuilders alternating
/// creating instructions will cause the instructions to be interleaved.
isConflictIP(IRBuilder<>::InsertPoint IP1,IRBuilder<>::InsertPoint IP2)66 static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
67                          IRBuilder<>::InsertPoint IP2) {
68   if (!IP1.isSet() || !IP2.isSet())
69     return false;
70   return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
71 }
72 
/// Assertion-only validity check (defined under !NDEBUG): returns true iff
/// \p SchedType is a schedule-type encoding the workshare-loop codegen can
/// emit, i.e. a supported base algorithm + ordering (+ optional nomerge)
/// combination with at most one monotonicity modifier.
static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  // Monotonicity bits are masked off here and checked separately below.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderdTrapezoidal: // (sic) enum name is spelled this way
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
130 #endif
131 
132 /// Determine which scheduling algorithm to use, determined from schedule clause
133 /// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    // A chunk expression selects the chunked static runtime schedule.
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    // The simd modifier maps to the simd-specialized guided schedule.
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}
156 
157 /// Adds ordering modifier flags to schedule type.
158 static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,bool HasOrderedClause)159 getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
160                               bool HasOrderedClause) {
161   assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
162              OMPScheduleType::None &&
163          "Must not have ordering nor monotonicity flags already set");
164 
165   OMPScheduleType OrderingModifier = HasOrderedClause
166                                          ? OMPScheduleType::ModifierOrdered
167                                          : OMPScheduleType::ModifierUnordered;
168   OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;
169 
170   // Unsupported combinations
171   if (OrderingScheduleType ==
172       (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
173     return OMPScheduleType::OrderedGuidedChunked;
174   else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
175                                     OMPScheduleType::ModifierOrdered))
176     return OMPScheduleType::OrderedRuntime;
177 
178   return OrderingScheduleType;
179 }
180 
181 /// Adds monotonicity modifier flags to schedule type.
182 static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,bool HasSimdModifier,bool HasMonotonic,bool HasNonmonotonic,bool HasOrderedClause)183 getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
184                                   bool HasSimdModifier, bool HasMonotonic,
185                                   bool HasNonmonotonic, bool HasOrderedClause) {
186   assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
187              OMPScheduleType::None &&
188          "Must not have monotonicity flags already set");
189   assert((!HasMonotonic || !HasNonmonotonic) &&
190          "Monotonic and Nonmonotonic are contradicting each other");
191 
192   if (HasMonotonic) {
193     return ScheduleType | OMPScheduleType::ModifierMonotonic;
194   } else if (HasNonmonotonic) {
195     return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
196   } else {
197     // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
198     // If the static schedule kind is specified or if the ordered clause is
199     // specified, and if the nonmonotonic modifier is not specified, the
200     // effect is as if the monotonic modifier is specified. Otherwise, unless
201     // the monotonic modifier is specified, the effect is as if the
202     // nonmonotonic modifier is specified.
203     OMPScheduleType BaseScheduleType =
204         ScheduleType & ~OMPScheduleType::ModifierMask;
205     if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
206         (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
207         HasOrderedClause) {
208       // The monotonic is used by default in openmp runtime library, so no need
209       // to set it.
210       return ScheduleType;
211     } else {
212       return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
213     }
214   }
215 }
216 
217 /// Determine the schedule type using schedule and ordering clause arguments.
218 static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind,bool HasChunks,bool HasSimdModifier,bool HasMonotonicModifier,bool HasNonmonotonicModifier,bool HasOrderedClause)219 computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
220                           bool HasSimdModifier, bool HasMonotonicModifier,
221                           bool HasNonmonotonicModifier, bool HasOrderedClause) {
222   OMPScheduleType BaseSchedule =
223       getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
224   OMPScheduleType OrderedSchedule =
225       getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
226   OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
227       OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
228       HasNonmonotonicModifier, HasOrderedClause);
229 
230   assert(isValidWorkshareLoopScheduleType(Result));
231   return Result;
232 }
233 
234 /// Make \p Source branch to \p Target.
235 ///
236 /// Handles two situations:
237 /// * \p Source already has an unconditional branch.
238 /// * \p Source is a degenerate block (no terminator because the BB is
239 ///             the current head of the IR construction).
redirectTo(BasicBlock * Source,BasicBlock * Target,DebugLoc DL)240 static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
241   if (Instruction *Term = Source->getTerminator()) {
242     auto *Br = cast<BranchInst>(Term);
243     assert(!Br->isConditional() &&
244            "BB's terminator must be an unconditional branch (or degenerate)");
245     BasicBlock *Succ = Br->getSuccessor(0);
246     Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
247     Br->setSuccessor(0, Target);
248     return;
249   }
250 
251   auto *NewBr = BranchInst::Create(Target, Source);
252   NewBr->setDebugLoc(DL);
253 }
254 
/// Move all instructions from \p IP to the end of their block into \p New,
/// optionally terminating the old block with a branch to \p New.
/// \p New must have no PHI nodes (they would need predecessor rewiring).
void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block. Splicing transfers the instructions
  // without re-creating them, so uses/users stay intact.
  BasicBlock *Old = IP.getBlock();
  New->getInstList().splice(New->begin(), Old->getInstList(), IP.getPoint(),
                            Old->end());

  // The old block lost its terminator (if it had one past IP); optionally
  // fall through into the new block.
  if (CreateBranch)
    BranchInst::Create(New, Old);
}
268 
spliceBB(IRBuilder<> & Builder,BasicBlock * New,bool CreateBranch)269 void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
270   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
271   BasicBlock *Old = Builder.GetInsertBlock();
272 
273   spliceBB(Builder.saveIP(), New, CreateBranch);
274   if (CreateBranch)
275     Builder.SetInsertPoint(Old->getTerminator());
276   else
277     Builder.SetInsertPoint(Old);
278 
279   // SetInsertPoint also updates the Builder's debug location, but we want to
280   // keep the one the Builder was configured to use.
281   Builder.SetCurrentDebugLocation(DebugLoc);
282 }
283 
splitBB(IRBuilderBase::InsertPoint IP,bool CreateBranch,llvm::Twine Name)284 BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
285                           llvm::Twine Name) {
286   BasicBlock *Old = IP.getBlock();
287   BasicBlock *New = BasicBlock::Create(
288       Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
289       Old->getParent(), Old->getNextNode());
290   spliceBB(IP, New, CreateBranch);
291   New->replaceSuccessorsPhiUsesWith(Old, New);
292   return New;
293 }
294 
splitBB(IRBuilderBase & Builder,bool CreateBranch,llvm::Twine Name)295 BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
296                           llvm::Twine Name) {
297   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
298   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
299   if (CreateBranch)
300     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
301   else
302     Builder.SetInsertPoint(Builder.GetInsertBlock());
303   // SetInsertPoint also updates the Builder's debug location, but we want to
304   // keep the one the Builder was configured to use.
305   Builder.SetCurrentDebugLocation(DebugLoc);
306   return New;
307 }
308 
splitBB(IRBuilder<> & Builder,bool CreateBranch,llvm::Twine Name)309 BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
310                           llvm::Twine Name) {
311   DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
312   BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
313   if (CreateBranch)
314     Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
315   else
316     Builder.SetInsertPoint(Builder.GetInsertBlock());
317   // SetInsertPoint also updates the Builder's debug location, but we want to
318   // keep the one the Builder was configured to use.
319   Builder.SetCurrentDebugLocation(DebugLoc);
320   return New;
321 }
322 
splitBBWithSuffix(IRBuilderBase & Builder,bool CreateBranch,llvm::Twine Suffix)323 BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
324                                     llvm::Twine Suffix) {
325   BasicBlock *Old = Builder.GetInsertBlock();
326   return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
327 }
328 
/// Attach the attribute sets declared in OMPKinds.def for runtime function
/// \p FnID to the declaration \p Fn (function, return, and per-argument
/// attributes). Functions without an OMP_RTL_ATTRS entry are left unchanged.
void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

// Materialize the named attribute sets referenced by OMP_RTL_ATTRS below.
#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    RetAttrs = RetAttrs.addAttributes(Ctx, RetAttrSet);                        \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      ArgAttrs[ArgNo] =                                                        \
          ArgAttrs[ArgNo].addAttributes(Ctx, ArgAttrSets[ArgNo]);              \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}
360 
/// Return a callee for the OpenMP runtime function \p FnID, reusing an
/// existing declaration in \p M if present, otherwise creating one with the
/// type, attributes, and (for fork calls) callback metadata from
/// OMPKinds.def. The result is bitcast to the expected function type in case
/// a declaration with a mismatching type already existed.
FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        //  - The callback callee is argument number 2 (microtask).
        //  - The first two arguments of the callback callee are unknown (-1).
        //  - All variadic arguments to the runtime function are passed to the
        //    callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  // Cast the function to the expected type if necessary
  Constant *C = ConstantExpr::getBitCast(Fn, FnTy->getPointerTo());
  return {FnTy, C};
}
419 
getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID)420 Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
421   FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
422   auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
423   assert(Fn && "Failed to create OpenMP runtime function pointer");
424   return Fn;
425 }
426 
// Set up the builder's cached type information for module M.
void OpenMPIRBuilder::initialize() { initializeTypes(M); }
428 
/// Perform the deferred outlining recorded in OutlineInfos: extract each
/// registered region into its own function via CodeExtractor. If \p Fn is
/// non-null, only regions belonging to \p Fn are processed; the rest are
/// kept for a later finalize() call (nested function generation).
void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not finalized yet; may happen with nested
    // function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    // Containers are reused across iterations to avoid reallocation.
    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock*/ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par");

    LLVM_DEBUG(dbgs() << "Before     outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    // Values the caller flagged as requiring their own parameter (e.g. the
    // thread id) are kept out of the aggregate argument struct.
    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    LLVM_DEBUG(dbgs() << "After      outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "   Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away, we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined region
      // and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      // Iterate in reverse so that each moveBefore-to-front preserves the
      // original relative order of the moved instructions.
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);
}
517 
// Destruction with pending outline regions is a client bug: finalize() must
// have been called for every function that registered an OutlineInfo.
OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}
521 
createGlobalFlag(unsigned Value,StringRef Name)522 GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
523   IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
524   auto *GV =
525       new GlobalVariable(M, I32Ty,
526                          /* isConstant = */ true, GlobalValue::WeakODRLinkage,
527                          ConstantInt::get(I32Ty, Value), Name);
528   GV->setVisibility(GlobalValue::HiddenVisibility);
529 
530   return GV;
531 }
532 
/// Return (and cache) a pointer to an `ident_t` global describing source
/// location \p SrcLocStr with the given flags, creating the global if no
/// equivalent one exists yet in the module.
Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  // Cache key combines the flags into one integer alongside the string.
  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for existing encoding of the location + flags, not needed but
    // minimizes the difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.getGlobalList())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  // Callers expect the default ident_t pointer type regardless of the
  // address space the global was created in.
  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}
572 
/// Return (and cache) a global string constant holding \p LocStr, reusing a
/// matching constant global if one already exists. \p SrcLocStrSize is set to
/// the string's length (excluding the implicit terminator).
Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for existing encoding of the location, not needed but minimizes the
    // difference to the existing solution while we transition.
    for (GlobalVariable &GV : M.getGlobalList())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}
593 
getOrCreateSrcLocStr(StringRef FunctionName,StringRef FileName,unsigned Line,unsigned Column,uint32_t & SrcLocStrSize)594 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
595                                                 StringRef FileName,
596                                                 unsigned Line, unsigned Column,
597                                                 uint32_t &SrcLocStrSize) {
598   SmallString<128> Buffer;
599   Buffer.push_back(';');
600   Buffer.append(FileName);
601   Buffer.push_back(';');
602   Buffer.append(FunctionName);
603   Buffer.push_back(';');
604   Buffer.append(std::to_string(Line));
605   Buffer.push_back(';');
606   Buffer.append(std::to_string(Column));
607   Buffer.push_back(';');
608   Buffer.push_back(';');
609   return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
610 }
611 
612 Constant *
getOrCreateDefaultSrcLocStr(uint32_t & SrcLocStrSize)613 OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
614   StringRef UnknownLoc = ";unknown;unknown;0;0;;";
615   return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
616 }
617 
getOrCreateSrcLocStr(DebugLoc DL,uint32_t & SrcLocStrSize,Function * F)618 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
619                                                 uint32_t &SrcLocStrSize,
620                                                 Function *F) {
621   DILocation *DIL = DL.get();
622   if (!DIL)
623     return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
624   StringRef FileName = M.getName();
625   if (DIFile *DIF = DIL->getFile())
626     if (Optional<StringRef> Source = DIF->getSource())
627       FileName = *Source;
628   StringRef Function = DIL->getScope()->getSubprogram()->getName();
629   if (Function.empty() && F)
630     Function = F->getName();
631   return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
632                               DIL->getColumn(), SrcLocStrSize);
633 }
634 
getOrCreateSrcLocStr(const LocationDescription & Loc,uint32_t & SrcLocStrSize)635 Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
636                                                 uint32_t &SrcLocStrSize) {
637   return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
638                               Loc.IP.getBlock()->getParent());
639 }
640 
getOrCreateThreadID(Value * Ident)641 Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
642   return Builder.CreateCall(
643       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
644       "omp_global_thread_num");
645 }
646 
647 OpenMPIRBuilder::InsertPointTy
createBarrier(const LocationDescription & Loc,Directive DK,bool ForceSimpleCall,bool CheckCancelFlag)648 OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive DK,
649                                bool ForceSimpleCall, bool CheckCancelFlag) {
650   if (!updateToLocation(Loc))
651     return Loc.IP;
652   return emitBarrierImpl(Loc, DK, ForceSimpleCall, CheckCancelFlag);
653 }
654 
/// Emit the runtime call implementing a barrier at the current insert point.
/// Uses __kmpc_cancel_barrier when inside a cancellable parallel region
/// (unless \p ForceSimpleCall), and optionally branches to the cancellation
/// handling code based on the call's result.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
                                 bool ForceSimpleCall, bool CheckCancelFlag) {
  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  // Encode in the ident flags which directive implied this barrier (explicit
  // barrier vs. implicit barrier at for/sections/single).
  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // Note: the thread id is obtained via a plain ident (no barrier flags).
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  // A non-zero result from the cancel barrier signals cancellation.
  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}
703 
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities like blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  // With an if-clause, only the "then" path performs the cancellation; the
  // "else" path falls through unchanged.
  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  // Map the canceled directive onto the runtime's cancel-kind constant; the
  // directive-to-value table is generated from OMPKinds.def.
  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  // When a parallel region is canceled, emit a simple (non-cancellable,
  // flag-ignoring) barrier on the exit path so threads synchronize before
  // leaving the region.
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}
755 
emitOffloadingEntry(Constant * Addr,StringRef Name,uint64_t Size,int32_t Flags,StringRef SectionName)756 void OpenMPIRBuilder::emitOffloadingEntry(Constant *Addr, StringRef Name,
757                                           uint64_t Size, int32_t Flags,
758                                           StringRef SectionName) {
759   Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
760   Type *Int32Ty = Type::getInt32Ty(M.getContext());
761   Type *SizeTy = M.getDataLayout().getIntPtrType(M.getContext());
762 
763   Constant *AddrName = ConstantDataArray::getString(M.getContext(), Name);
764 
765   // Create the constant string used to look up the symbol in the device.
766   auto *Str =
767       new llvm::GlobalVariable(M, AddrName->getType(), /*isConstant=*/true,
768                                llvm::GlobalValue::InternalLinkage, AddrName,
769                                ".omp_offloading.entry_name");
770   Str->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
771 
772   // Construct the offloading entry.
773   Constant *EntryData[] = {
774       ConstantExpr::getPointerBitCastOrAddrSpaceCast(Addr, Int8PtrTy),
775       ConstantExpr::getPointerBitCastOrAddrSpaceCast(Str, Int8PtrTy),
776       ConstantInt::get(SizeTy, Size),
777       ConstantInt::get(Int32Ty, Flags),
778       ConstantInt::get(Int32Ty, 0),
779   };
780   Constant *EntryInitializer =
781       ConstantStruct::get(OpenMPIRBuilder::OffloadEntry, EntryData);
782 
783   auto *Entry = new GlobalVariable(
784       M, OpenMPIRBuilder::OffloadEntry,
785       /* isConstant = */ true, GlobalValue::WeakAnyLinkage, EntryInitializer,
786       ".omp_offloading.entry." + Name, nullptr, GlobalValue::NotThreadLocal,
787       M.getDataLayout().getDefaultGlobalsAddressSpace());
788 
789   // The entry has to be created in the section the linker expects it to be.
790   Entry->setSection(SectionName);
791   Entry->setAlignment(Align(1));
792 }
793 
emitTargetKernel(const LocationDescription & Loc,Value * & Return,Value * Ident,Value * DeviceID,Value * NumTeams,Value * NumThreads,Value * HostPtr,ArrayRef<Value * > KernelArgs,ArrayRef<Value * > NoWaitArgs)794 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
795     const LocationDescription &Loc, Value *&Return, Value *Ident,
796     Value *DeviceID, Value *NumTeams, Value *NumThreads, Value *HostPtr,
797     ArrayRef<Value *> KernelArgs, ArrayRef<Value *> NoWaitArgs) {
798   if (!updateToLocation(Loc))
799     return Loc.IP;
800 
801   auto *KernelArgsPtr =
802       Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
803   for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
804     llvm::Value *Arg =
805         Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
806     Builder.CreateAlignedStore(
807         KernelArgs[I], Arg,
808         M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
809   }
810 
811   bool HasNoWait = !NoWaitArgs.empty();
812   SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
813                                       NumThreads, HostPtr,  KernelArgsPtr};
814   if (HasNoWait)
815     OffloadingArgs.append(NoWaitArgs.begin(), NoWaitArgs.end());
816 
817   Return = Builder.CreateCall(
818       HasNoWait
819           ? getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel_nowait)
820           : getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
821       OffloadingArgs);
822 
823   return Builder.saveIP();
824 }
825 
emitCancelationCheckImpl(Value * CancelFlag,omp::Directive CanceledDirective,FinalizeCallbackTy ExitCB)826 void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
827                                                omp::Directive CanceledDirective,
828                                                FinalizeCallbackTy ExitCB) {
829   assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
830          "Unexpected cancellation!");
831 
832   // For a cancel barrier we create two new blocks.
833   BasicBlock *BB = Builder.GetInsertBlock();
834   BasicBlock *NonCancellationBlock;
835   if (Builder.GetInsertPoint() == BB->end()) {
836     // TODO: This branch will not be needed once we moved to the
837     // OpenMPIRBuilder codegen completely.
838     NonCancellationBlock = BasicBlock::Create(
839         BB->getContext(), BB->getName() + ".cont", BB->getParent());
840   } else {
841     NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
842     BB->getTerminator()->eraseFromParent();
843     Builder.SetInsertPoint(BB);
844   }
845   BasicBlock *CancellationBlock = BasicBlock::Create(
846       BB->getContext(), BB->getName() + ".cncl", BB->getParent());
847 
848   // Jump to them based on the return value.
849   Value *Cmp = Builder.CreateIsNull(CancelFlag);
850   Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
851                        /* TODO weight */ nullptr, nullptr);
852 
853   // From the cancellation block we finalize all variables and go to the
854   // post finalization block that is known to the FiniCB callback.
855   Builder.SetInsertPoint(CancellationBlock);
856   if (ExitCB)
857     ExitCB(Builder.saveIP());
858   auto &FI = FinalizationStack.back();
859   FI.FiniCB(Builder.saveIP());
860 
861   // The continuation block is where code generation continues.
862   Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
863 }
864 
IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);

  if (NumThreads) {
    // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  Builder.restoreIP(OuterAllocaIP);
  AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
  AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");

  // If there is an if condition we actually use the TIDAddr and ZeroAddr in the
  // program, otherwise we only need them for modeling purposes to get the
  // associated arguments in the outlined function. In the former case,
  // initialize the allocas properly, in the latter case, delete them later.
  if (IfCondition) {
    Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr);
    Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr);
  } else {
    ToBeDeleted.push_back(TIDAddr);
    ToBeDeleted.push_back(ZeroAddr);
  }

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);

  // Carve the region into entry/body/pre-finalize/exit blocks (see diagram
  // below); the entry..exit span is what gets outlined.
  BasicBlock *ThenBB = ThenTI->getParent();
  BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry");
  BasicBlock *PRegBodyBB =
      PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize");
  BasicBlock *PRegExitBB =
      PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit");

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  // Make the wrapped finalization callback visible to nested constructs
  // (e.g. cancellation) while the body is generated.
  FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");

  // Add some fake uses for OpenMP provided arguments.
  ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
  ToBeDeleted.push_back(ZeroAddrUse);

  // ThenBB
  //   |
  //   V
  // PRegionEntryBB         <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB          <- BodeGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB          <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB          <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  BodyGenCB(InnerAllocaIP, CodeGenIP);

  LLVM_DEBUG(dbgs() << "After  body codegen: " << *OuterFn << "\n");

  FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
  if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
      llvm::LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      //  - The callback callee is argument number 2 (microtask).
      //  - The first two arguments of the callback callee are unknown (-1).
      //  - All variadic arguments to the __kmpc_fork_call are passed to the
      //    callback callee.
      F->addMetadata(
          llvm::LLVMContext::MD_callback,
          *llvm::MDNode::get(
              Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
                                               /* VarArgsArePassed */ true)}));
    }
  }

  // Outlining is deferred: PostOutlineCB later rewrites the extracted call
  // into __kmpc_fork_call (or, with an if-clause, also a serialized path).
  OutlineInfo OI;
  OI.PostOutlineCB = [=](Function &OutlinedFn) {
    // Add some known attributes.
    OutlinedFn.addParamAttr(0, Attribute::NoAlias);
    OutlinedFn.addParamAttr(1, Attribute::NoAlias);
    OutlinedFn.addFnAttr(Attribute::NoUnwind);
    OutlinedFn.addFnAttr(Attribute::NoRecurse);

    assert(OutlinedFn.arg_size() >= 2 &&
           "Expected at least tid and bounded tid as arguments");
    unsigned NumCapturedVars =
        OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

    CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
    CI->getParent()->setName("omp_parallel");
    Builder.SetInsertPoint(CI);

    // Build call __kmpc_fork_call(Ident, n, microtask, var1, .., varn);
    Value *ForkCallArgs[] = {
        Ident, Builder.getInt32(NumCapturedVars),
        Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};

    SmallVector<Value *, 16> RealArgs;
    RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
    RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

    Builder.CreateCall(RTLFn, RealArgs);

    LLVM_DEBUG(dbgs() << "With fork_call placed: "
                      << *Builder.GetInsertBlock()->getParent() << "\n");

    InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());

    // Initialize the local TID stack location with the argument value.
    Builder.SetInsertPoint(PrivTID);
    Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
    Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);

    // If no "if" clause was present we do not need the call created during
    // outlining, otherwise we reuse it in the serialized parallel region.
    if (!ElseTI) {
      CI->eraseFromParent();
    } else {

      // If an "if" clause was present we are now generating the serialized
      // version into the "else" branch.
      Builder.SetInsertPoint(ElseTI);

      // Build calls __kmpc_serialized_parallel(&Ident, GTid);
      Value *SerializedParallelCallArgs[] = {Ident, ThreadID};
      Builder.CreateCall(
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel),
          SerializedParallelCallArgs);

      // OutlinedFn(&GTid, &zero, CapturedStruct);
      CI->removeFromParent();
      Builder.Insert(CI);

      // __kmpc_end_serialized_parallel(&Ident, GTid);
      Value *EndArgs[] = {Ident, ThreadID};
      Builder.CreateCall(
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel),
          EndArgs);

      LLVM_DEBUG(dbgs() << "With serialized parallel region: "
                        << *Builder.GetInsertBlock()->getParent() << "\n");
    }

    for (Instruction *I : ToBeDeleted)
      I->eraseFromParent();
  };

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  FiniCB(PreFiniIP);

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);

  // Ensure a single exit node for the outlined region by creating one.
  // We might have multiple incoming edges to the exit now due to finalizations,
  // e.g., cancel calls that cause the control flow to leave the region.
  BasicBlock *PRegOutlinedExitBB = PRegExitBB;
  PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
  PRegOutlinedExitBB->setName("omp.par.outlined.exit");
  Blocks.push_back(PRegOutlinedExitBB);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par");

  // Find inputs to, outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);

  // Privatize one value captured by the region: replace all its uses inside
  // the region with the replacement the PrivCB callback (or PrivTID for the
  // thread-id call) provides.
  auto PrivHelper = [&](Value &V) {
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(&V);
      return;
    }

    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
        if (ParallelRegionBlockSet.count(UserI->getParent()))
          Uses.insert(&U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the entry
      // block of the to-be-outlined region.
      Builder.SetInsertPoint(InsertBB,
                             InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(&V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(InnerAllocaIP);
      Inner = Builder.CreateLoad(V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      ReplacementValue = PrivTID;
    } else {
      Builder.restoreIP(
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      if (ReplacementValue == &V)
        return;
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of zero address so
  // that they are available in the generated body and so that the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    PrivHelper(*Input);
  }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After  privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Register the outlined info.
  addOutlineInfo(std::move(OI));

  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
1235 
emitFlush(const LocationDescription & Loc)1236 void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
1237   // Build call void __kmpc_flush(ident_t *loc)
1238   uint32_t SrcLocStrSize;
1239   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1240   Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};
1241 
1242   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
1243 }
1244 
createFlush(const LocationDescription & Loc)1245 void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
1246   if (!updateToLocation(Loc))
1247     return;
1248   emitFlush(Loc);
1249 }
1250 
emitTaskwaitImpl(const LocationDescription & Loc)1251 void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
1252   // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
1253   // global_tid);
1254   uint32_t SrcLocStrSize;
1255   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1256   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1257   Value *Args[] = {Ident, getOrCreateThreadID(Ident)};
1258 
1259   // Ignore return result until untied tasks are supported.
1260   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
1261                      Args);
1262 }
1263 
createTaskwait(const LocationDescription & Loc)1264 void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
1265   if (!updateToLocation(Loc))
1266     return;
1267   emitTaskwaitImpl(Loc);
1268 }
1269 
emitTaskyieldImpl(const LocationDescription & Loc)1270 void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
1271   // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
1272   uint32_t SrcLocStrSize;
1273   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1274   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1275   Constant *I32Null = ConstantInt::getNullValue(Int32);
1276   Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};
1277 
1278   Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
1279                      Args);
1280 }
1281 
createTaskyield(const LocationDescription & Loc)1282 void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
1283   if (!updateToLocation(Loc))
1284     return;
1285   emitTaskyieldImpl(Loc);
1286 }
1287 
1288 OpenMPIRBuilder::InsertPointTy
createTask(const LocationDescription & Loc,InsertPointTy AllocaIP,BodyGenCallbackTy BodyGenCB,bool Tied,Value * Final)1289 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1290                             InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1291                             bool Tied, Value *Final) {
1292   if (!updateToLocation(Loc))
1293     return InsertPointTy();
1294 
1295   uint32_t SrcLocStrSize;
1296   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1297   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1298   // The current basic block is split into four basic blocks. After outlining,
1299   // they will be mapped as follows:
1300   // ```
1301   // def current_fn() {
1302   //   current_basic_block:
1303   //     br label %task.exit
1304   //   task.exit:
1305   //     ; instructions after task
1306   // }
1307   // def outlined_fn() {
1308   //   task.alloca:
1309   //     br label %task.body
1310   //   task.body:
1311   //     ret void
1312   // }
1313   // ```
1314   BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1315   BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1316   BasicBlock *TaskAllocaBB =
1317       splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1318 
1319   OutlineInfo OI;
1320   OI.EntryBB = TaskAllocaBB;
1321   OI.OuterAllocaBB = AllocaIP.getBlock();
1322   OI.ExitBB = TaskExitBB;
1323   OI.PostOutlineCB = [this, Ident, Tied, Final](Function &OutlinedFn) {
1324     // The input IR here looks like the following-
1325     // ```
1326     // func @current_fn() {
1327     //   outlined_fn(%args)
1328     // }
1329     // func @outlined_fn(%args) { ... }
1330     // ```
1331     //
1332     // This is changed to the following-
1333     //
1334     // ```
1335     // func @current_fn() {
1336     //   runtime_call(..., wrapper_fn, ...)
1337     // }
1338     // func @wrapper_fn(..., %args) {
1339     //   outlined_fn(%args)
1340     // }
1341     // func @outlined_fn(%args) { ... }
1342     // ```
1343 
1344     // The stale call instruction will be replaced with a new call instruction
1345     // for runtime call with a wrapper function.
1346     assert(OutlinedFn.getNumUses() == 1 &&
1347            "there must be a single user for the outlined function");
1348     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1349 
1350     // HasTaskData is true if any variables are captured in the outlined region,
1351     // false otherwise.
1352     bool HasTaskData = StaleCI->arg_size() > 0;
1353     Builder.SetInsertPoint(StaleCI);
1354 
1355     // Gather the arguments for emitting the runtime call for
1356     // @__kmpc_omp_task_alloc
1357     Function *TaskAllocFn =
1358         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1359 
1360     // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
1361     // call.
1362     Value *ThreadID = getOrCreateThreadID(Ident);
1363 
1364     // Argument - `flags`
1365     // Task is tied iff (Flags & 1) == 1.
1366     // Task is untied iff (Flags & 1) == 0.
1367     // Task is final iff (Flags & 2) == 2.
1368     // Task is not final iff (Flags & 2) == 0.
1369     // TODO: Handle the other flags.
1370     Value *Flags = Builder.getInt32(Tied);
1371     if (Final) {
1372       Value *FinalFlag =
1373           Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1374       Flags = Builder.CreateOr(FinalFlag, Flags);
1375     }
1376 
1377     // Argument - `sizeof_kmp_task_t` (TaskSize)
1378     // Tasksize refers to the size in bytes of kmp_task_t data structure
1379     // including private vars accessed in task.
1380     Value *TaskSize = Builder.getInt64(0);
1381     if (HasTaskData) {
1382       AllocaInst *ArgStructAlloca =
1383           dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
1384       assert(ArgStructAlloca &&
1385              "Unable to find the alloca instruction corresponding to arguments "
1386              "for extracted function");
1387       StructType *ArgStructType =
1388           dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1389       assert(ArgStructType && "Unable to find struct type corresponding to "
1390                               "arguments for extracted function");
1391       TaskSize =
1392           Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1393     }
1394 
1395     // TODO: Argument - sizeof_shareds
1396 
1397     // Argument - task_entry (the wrapper function)
1398     // If the outlined function has some captured variables (i.e. HasTaskData is
1399     // true), then the wrapper function will have an additional argument (the
1400     // struct containing captured variables). Otherwise, no such argument will
1401     // be present.
1402     SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
1403     if (HasTaskData)
1404       WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
1405     FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
1406         (Twine(OutlinedFn.getName()) + ".wrapper").str(),
1407         FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
1408     Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
1409     PointerType *WrapperFuncBitcastType =
1410         FunctionType::get(Builder.getInt32Ty(),
1411                           {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
1412             ->getPointerTo();
1413     Value *WrapperFuncBitcast =
1414         ConstantExpr::getBitCast(WrapperFunc, WrapperFuncBitcastType);
1415 
1416     // Emit the @__kmpc_omp_task_alloc runtime call
1417     // The runtime call returns a pointer to an area where the task captured
1418     // variables must be copied before the task is run (NewTaskData)
1419     CallInst *NewTaskData = Builder.CreateCall(
1420         TaskAllocFn,
1421         {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1422          /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
1423          /*task_func=*/WrapperFuncBitcast});
1424 
1425     // Copy the arguments for outlined function
1426     if (HasTaskData) {
1427       Value *TaskData = StaleCI->getArgOperand(0);
1428       Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1429       Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
1430                            TaskSize);
1431     }
1432 
1433     // Emit the @__kmpc_omp_task runtime call to spawn the task
1434     Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
1435     Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
1436 
1437     StaleCI->eraseFromParent();
1438 
1439     // Emit the body for wrapper function
1440     BasicBlock *WrapperEntryBB =
1441         BasicBlock::Create(M.getContext(), "", WrapperFunc);
1442     Builder.SetInsertPoint(WrapperEntryBB);
1443     if (HasTaskData)
1444       Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
1445     else
1446       Builder.CreateCall(&OutlinedFn);
1447     Builder.CreateRet(Builder.getInt32(0));
1448   };
1449 
1450   addOutlineInfo(std::move(OI));
1451 
1452   InsertPointTy TaskAllocaIP =
1453       InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1454   InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1455   BodyGenCB(TaskAllocaIP, TaskBodyIP);
1456   Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
1457 
1458   return Builder.saveIP();
1459 }
1460 
1461 OpenMPIRBuilder::InsertPointTy
createTaskgroup(const LocationDescription & Loc,InsertPointTy AllocaIP,BodyGenCallbackTy BodyGenCB)1462 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
1463                                  InsertPointTy AllocaIP,
1464                                  BodyGenCallbackTy BodyGenCB) {
1465   if (!updateToLocation(Loc))
1466     return InsertPointTy();
1467 
1468   uint32_t SrcLocStrSize;
1469   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1470   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1471   Value *ThreadID = getOrCreateThreadID(Ident);
1472 
1473   // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
1474   Function *TaskgroupFn =
1475       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
1476   Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
1477 
1478   BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
1479   BodyGenCB(AllocaIP, Builder.saveIP());
1480 
1481   Builder.SetInsertPoint(TaskgroupExitBB);
1482   // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
1483   Function *EndTaskgroupFn =
1484       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
1485   Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
1486 
1487   return Builder.saveIP();
1488 }
1489 
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
  assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");

  if (!updateToLocation(Loc))
    return Loc.IP;

  // Wrap the user finalization callback so it also works when the insertion
  // point sits at the very end of a block that has no terminator (the
  // cancellation block).
  auto FiniCBWrapper = [&](InsertPointTy IP) {
    if (IP.getBlock()->end() != IP.getPoint())
      return FiniCB(IP);
    // This must be done otherwise any nested constructs using FinalizeOMPRegion
    // will fail because that function requires the Finalization Basic Block to
    // have a terminator, which is already removed by EmitOMPRegionBody.
    // IP is currently at cancelation block.
    // We need to backtrack to the condition block to fetch
    // the exit block and create a branch from cancelation
    // to exit block.
    IRBuilder<>::InsertPointGuard IPG(Builder);
    Builder.restoreIP(IP);
    auto *CaseBB = IP.getBlock()->getSinglePredecessor();
    auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
    auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
    Instruction *I = Builder.CreateBr(ExitBB);
    IP = InsertPointTy(I->getParent(), I->getIterator());
    return FiniCB(IP);
  };

  FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});

  // Each section is emitted as a switch case
  // Each finalization callback is handled from clang.EmitOMPSectionDirective()
  // -> OMP.createSection() which generates the IR for each section
  // Iterate through all sections and emit a switch construct:
  // switch (IV) {
  //   case 0:
  //     <SectionStmt[0]>;
  //     break;
  // ...
  //   case <NumSection> - 1:
  //     <SectionStmt[<NumSection> - 1]>;
  //     break;
  // }
  // ...
  // section_loop.after:
  // <FiniCB>;
  auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
    Builder.restoreIP(CodeGenIP);
    // Split the body block; everything after the split becomes the switch's
    // default destination (fall-through past all cases).
    BasicBlock *Continue =
        splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
    Function *CurFn = Continue->getParent();
    SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);

    // Emit one case block per section; each case branches back to Continue
    // and the section callback fills in the body before that branch.
    unsigned CaseNumber = 0;
    for (auto SectionCB : SectionCBs) {
      BasicBlock *CaseBB = BasicBlock::Create(
          M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
      SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
      Builder.SetInsertPoint(CaseBB);
      BranchInst *CaseEndBr = Builder.CreateBr(Continue);
      SectionCB(InsertPointTy(),
                {CaseEndBr->getParent(), CaseEndBr->getIterator()});
      CaseNumber++;
    }
    // remove the existing terminator from body BB since there can be no
    // terminators after switch/case
  };
  // Loop body ends here
  // LowerBound, UpperBound, and Stride for createCanonicalLoop
  Type *I32Ty = Type::getInt32Ty(M.getContext());
  Value *LB = ConstantInt::get(I32Ty, 0);
  Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
  Value *ST = ConstantInt::get(I32Ty, 1);
  // The sections construct is lowered to a canonical loop over the section
  // indices, then statically workshared across the team.
  llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
      Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
  InsertPointTy AfterIP =
      applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);

  // Apply the finalization callback in LoopAfterBB
  auto FiniInfo = FinalizationStack.pop_back_val();
  assert(FiniInfo.DK == OMPD_sections &&
         "Unexpected finalization stack state!");
  if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
    Builder.restoreIP(AfterIP);
    BasicBlock *FiniBB =
        splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
    CB(Builder.saveIP());
    AfterIP = {FiniBB, FiniBB->begin()};
  }

  return AfterIP;
}
1583 
1584 OpenMPIRBuilder::InsertPointTy
createSection(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB)1585 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
1586                                BodyGenCallbackTy BodyGenCB,
1587                                FinalizeCallbackTy FiniCB) {
1588   if (!updateToLocation(Loc))
1589     return Loc.IP;
1590 
1591   auto FiniCBWrapper = [&](InsertPointTy IP) {
1592     if (IP.getBlock()->end() != IP.getPoint())
1593       return FiniCB(IP);
1594     // This must be done otherwise any nested constructs using FinalizeOMPRegion
1595     // will fail because that function requires the Finalization Basic Block to
1596     // have a terminator, which is already removed by EmitOMPRegionBody.
1597     // IP is currently at cancelation block.
1598     // We need to backtrack to the condition block to fetch
1599     // the exit block and create a branch from cancelation
1600     // to exit block.
1601     IRBuilder<>::InsertPointGuard IPG(Builder);
1602     Builder.restoreIP(IP);
1603     auto *CaseBB = Loc.IP.getBlock();
1604     auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1605     auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1606     Instruction *I = Builder.CreateBr(ExitBB);
1607     IP = InsertPointTy(I->getParent(), I->getIterator());
1608     return FiniCB(IP);
1609   };
1610 
1611   Directive OMPD = Directive::OMPD_sections;
1612   // Since we are using Finalization Callback here, HasFinalize
1613   // and IsCancellable have to be true
1614   return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
1615                               /*Conditional*/ false, /*hasFinalize*/ true,
1616                               /*IsCancellable*/ true);
1617 }
1618 
1619 /// Create a function with a unique name and a "void (i8*, i8*)" signature in
1620 /// the given module and return it.
getFreshReductionFunc(Module & M)1621 Function *getFreshReductionFunc(Module &M) {
1622   Type *VoidTy = Type::getVoidTy(M.getContext());
1623   Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
1624   auto *FuncTy =
1625       FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
1626   return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
1627                           M.getDataLayout().getDefaultGlobalsAddressSpace(),
1628                           ".omp.reduction.func", &M);
1629 }
1630 
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
    const LocationDescription &Loc, InsertPointTy AllocaIP,
    ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
  // Validate the reduction descriptors up front. RI is only referenced by the
  // asserts, so (void)RI keeps release builds free of unused-variable
  // warnings.
  for (const ReductionInfo &RI : ReductionInfos) {
    (void)RI;
    assert(RI.Variable && "expected non-null variable");
    assert(RI.PrivateVariable && "expected non-null private variable");
    assert(RI.ReductionGen && "expected non-null reduction generator callback");
    assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
           "expected variables and their private equivalents to have the same "
           "type");
    assert(RI.Variable->getType()->isPointerTy() &&
           "expected variables to be pointers");
  }

  if (!updateToLocation(Loc))
    return InsertPointTy();

  // Split off everything after the insertion point into a continuation block;
  // the reduction control flow is emitted in between, so the branch that
  // splitBasicBlock created is dropped again.
  BasicBlock *InsertBlock = Loc.IP.getBlock();
  BasicBlock *ContinuationBlock =
      InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
  InsertBlock->getTerminator()->eraseFromParent();

  // Create and populate array of type-erased pointers to private reduction
  // values.
  unsigned NumReductions = ReductionInfos.size();
  Type *RedArrayTy = ArrayType::get(Builder.getInt8PtrTy(), NumReductions);
  Builder.restoreIP(AllocaIP);
  Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");

  Builder.SetInsertPoint(InsertBlock, InsertBlock->end());

  // Store an i8* to each private variable into its slot of the array.
  for (auto En : enumerate(ReductionInfos)) {
    unsigned Index = En.index();
    const ReductionInfo &RI = En.value();
    Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
    Value *Casted =
        Builder.CreateBitCast(RI.PrivateVariable, Builder.getInt8PtrTy(),
                              "private.red.var." + Twine(Index) + ".casted");
    Builder.CreateStore(Casted, RedArrayElemPtr);
  }

  // Emit a call to the runtime function that orchestrates the reduction.
  // Declare the reduction function in the process.
  Function *Func = Builder.GetInsertBlock()->getParent();
  Module *Module = Func->getParent();
  Value *RedArrayPtr =
      Builder.CreateBitCast(RedArray, Builder.getInt8PtrTy(), "red.array.ptr");
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  // The atomic fast path is only advertised to the runtime (via the ident
  // flag below) if every reduction provides an atomic generator.
  bool CanGenerateAtomic =
      llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
        return RI.AtomicReductionGen;
      });
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                  CanGenerateAtomic
                                      ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
                                      : IdentFlag(0));
  Value *ThreadId = getOrCreateThreadID(Ident);
  Constant *NumVariables = Builder.getInt32(NumReductions);
  const DataLayout &DL = Module->getDataLayout();
  unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
  Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
  Function *ReductionFunc = getFreshReductionFunc(*Module);
  Value *Lock = getOMPCriticalRegionLock(".reduction");
  Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_reduce);
  CallInst *ReduceCall =
      Builder.CreateCall(ReduceFunc,
                         {Ident, ThreadId, NumVariables, RedArraySize,
                          RedArrayPtr, ReductionFunc, Lock},
                         "reduce");

  // Create final reduction entry blocks for the atomic and non-atomic case.
  // Emit IR that dispatches control flow to one of the blocks based on the
  // reduction supporting the atomic mode. Any other return value of the
  // runtime call falls through to the continuation block.
  BasicBlock *NonAtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
  BasicBlock *AtomicRedBlock =
      BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
  SwitchInst *Switch =
      Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
  Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
  Switch->addCase(Builder.getInt32(2), AtomicRedBlock);

  // Populate the non-atomic reduction using the elementwise reduction function.
  // This loads the elements from the global and private variables and reduces
  // them before storing back the result to the global variable.
  Builder.SetInsertPoint(NonAtomicRedBlock);
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Type *ValueType = RI.ElementType;
    Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
                                         "red.value." + Twine(En.index()));
    Value *PrivateRedValue =
        Builder.CreateLoad(ValueType, RI.PrivateVariable,
                           "red.private.value." + Twine(En.index()));
    Value *Reduced;
    Builder.restoreIP(
        RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
    // A cleared insert block signals that the callback failed to generate
    // code; give up and return an unset insert point.
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    Builder.CreateStore(Reduced, RI.Variable);
  }
  Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
      IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
               : RuntimeFunction::OMPRTL___kmpc_end_reduce);
  Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
  Builder.CreateBr(ContinuationBlock);

  // Populate the atomic reduction using the atomic elementwise reduction
  // function. There are no loads/stores here because they will be happening
  // inside the atomic elementwise reduction.
  Builder.SetInsertPoint(AtomicRedBlock);
  if (CanGenerateAtomic) {
    for (const ReductionInfo &RI : ReductionInfos) {
      Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
                                              RI.Variable, RI.PrivateVariable));
      if (!Builder.GetInsertBlock())
        return InsertPointTy();
    }
    Builder.CreateBr(ContinuationBlock);
  } else {
    // The atomic flag was not set in the ident, so this path is never taken;
    // mark it unreachable.
    Builder.CreateUnreachable();
  }

  // Populate the outlined reduction function using the elementwise reduction
  // function. Partial values are extracted from the type-erased array of
  // pointers to private variables.
  BasicBlock *ReductionFuncBlock =
      BasicBlock::Create(Module->getContext(), "", ReductionFunc);
  Builder.SetInsertPoint(ReductionFuncBlock);
  Value *LHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(0),
                                             RedArrayTy->getPointerTo());
  Value *RHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(1),
                                             RedArrayTy->getPointerTo());
  for (auto En : enumerate(ReductionInfos)) {
    const ReductionInfo &RI = En.value();
    Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, LHSArrayPtr, 0, En.index());
    Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
    Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
    Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
    Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
        RedArrayTy, RHSArrayPtr, 0, En.index());
    Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
    Value *RHSPtr =
        Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
    Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
    Value *Reduced;
    Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
    if (!Builder.GetInsertBlock())
      return InsertPointTy();
    // The reduced value is stored back through the LHS pointer.
    Builder.CreateStore(Reduced, LHSPtr);
  }
  Builder.CreateRetVoid();

  Builder.SetInsertPoint(ContinuationBlock);
  return Builder.saveIP();
}
1793 
1794 OpenMPIRBuilder::InsertPointTy
createMaster(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB)1795 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
1796                               BodyGenCallbackTy BodyGenCB,
1797                               FinalizeCallbackTy FiniCB) {
1798 
1799   if (!updateToLocation(Loc))
1800     return Loc.IP;
1801 
1802   Directive OMPD = Directive::OMPD_master;
1803   uint32_t SrcLocStrSize;
1804   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1805   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1806   Value *ThreadId = getOrCreateThreadID(Ident);
1807   Value *Args[] = {Ident, ThreadId};
1808 
1809   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
1810   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1811 
1812   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
1813   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
1814 
1815   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1816                               /*Conditional*/ true, /*hasFinalize*/ true);
1817 }
1818 
1819 OpenMPIRBuilder::InsertPointTy
createMasked(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,Value * Filter)1820 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
1821                               BodyGenCallbackTy BodyGenCB,
1822                               FinalizeCallbackTy FiniCB, Value *Filter) {
1823   if (!updateToLocation(Loc))
1824     return Loc.IP;
1825 
1826   Directive OMPD = Directive::OMPD_masked;
1827   uint32_t SrcLocStrSize;
1828   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1829   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1830   Value *ThreadId = getOrCreateThreadID(Ident);
1831   Value *Args[] = {Ident, ThreadId, Filter};
1832   Value *ArgsEnd[] = {Ident, ThreadId};
1833 
1834   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
1835   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1836 
1837   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
1838   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
1839 
1840   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1841                               /*Conditional*/ true, /*hasFinalize*/ true);
1842 }
1843 
CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
    DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
    BasicBlock *PostInsertBefore, const Twine &Name) {
  Module *M = F->getParent();
  LLVMContext &Ctx = M->getContext();
  // The induction variable has the same type as the trip count.
  Type *IndVarTy = TripCount->getType();

  // Create the basic block structure:
  //   preheader -> header -> cond -> body -> inc -> (back to header)
  //                            `-> exit -> after
  // Preheader..Body are placed before PreInsertBefore, Latch..After before
  // PostInsertBefore, so a caller can sandwich existing blocks in between.
  BasicBlock *Preheader =
      BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
  BasicBlock *Header =
      BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
  BasicBlock *Cond =
      BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
  BasicBlock *Body =
      BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
  BasicBlock *Latch =
      BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
  BasicBlock *Exit =
      BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
  BasicBlock *After =
      BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);

  // Use specified DebugLoc for new instructions.
  Builder.SetCurrentDebugLocation(DL);

  Builder.SetInsertPoint(Preheader);
  Builder.CreateBr(Header);

  // The induction variable starts at 0 and is advanced in the latch.
  Builder.SetInsertPoint(Header);
  PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
  IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
  Builder.CreateBr(Cond);

  // Loop while IV u< TripCount (canonical loops are unsigned).
  Builder.SetInsertPoint(Cond);
  Value *Cmp =
      Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
  Builder.CreateCondBr(Cmp, Body, Exit);

  Builder.SetInsertPoint(Body);
  Builder.CreateBr(Latch);

  // Increment cannot wrap: IV stays below TripCount, hence HasNUW.
  Builder.SetInsertPoint(Latch);
  Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
                                  "omp_" + Name + ".next", /*HasNUW=*/true);
  Builder.CreateBr(Header);
  IndVarPHI->addIncoming(Next, Latch);

  Builder.SetInsertPoint(Exit);
  Builder.CreateBr(After);

  // Remember and return the canonical control flow.
  LoopInfos.emplace_front();
  CanonicalLoopInfo *CL = &LoopInfos.front();

  CL->Header = Header;
  CL->Cond = Cond;
  CL->Latch = Latch;
  CL->Exit = Exit;

#ifndef NDEBUG
  CL->assertOK();
#endif
  return CL;
}
1909 
1910 CanonicalLoopInfo *
createCanonicalLoop(const LocationDescription & Loc,LoopBodyGenCallbackTy BodyGenCB,Value * TripCount,const Twine & Name)1911 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
1912                                      LoopBodyGenCallbackTy BodyGenCB,
1913                                      Value *TripCount, const Twine &Name) {
1914   BasicBlock *BB = Loc.IP.getBlock();
1915   BasicBlock *NextBB = BB->getNextNode();
1916 
1917   CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
1918                                              NextBB, NextBB, Name);
1919   BasicBlock *After = CL->getAfter();
1920 
1921   // If location is not set, don't connect the loop.
1922   if (updateToLocation(Loc)) {
1923     // Split the loop at the insertion point: Branch to the preheader and move
1924     // every following instruction to after the loop (the After BB). Also, the
1925     // new successor is the loop's after block.
1926     spliceBB(Builder, After, /*CreateBranch=*/false);
1927     Builder.CreateBr(CL->getPreheader());
1928   }
1929 
1930   // Emit the body content. We do it after connecting the loop to the CFG to
1931   // avoid that the callback encounters degenerate BBs.
1932   BodyGenCB(CL->getBodyIP(), CL->getIndVar());
1933 
1934 #ifndef NDEBUG
1935   CL->assertOK();
1936 #endif
1937   return CL;
1938 }
1939 
createCanonicalLoop(const LocationDescription & Loc,LoopBodyGenCallbackTy BodyGenCB,Value * Start,Value * Stop,Value * Step,bool IsSigned,bool InclusiveStop,InsertPointTy ComputeIP,const Twine & Name)1940 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
1941     const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
1942     Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
1943     InsertPointTy ComputeIP, const Twine &Name) {
1944 
1945   // Consider the following difficulties (assuming 8-bit signed integers):
1946   //  * Adding \p Step to the loop counter which passes \p Stop may overflow:
1947   //      DO I = 1, 100, 50
1948   ///  * A \p Step of INT_MIN cannot not be normalized to a positive direction:
1949   //      DO I = 100, 0, -128
1950 
1951   // Start, Stop and Step must be of the same integer type.
1952   auto *IndVarTy = cast<IntegerType>(Start->getType());
1953   assert(IndVarTy == Stop->getType() && "Stop type mismatch");
1954   assert(IndVarTy == Step->getType() && "Step type mismatch");
1955 
1956   LocationDescription ComputeLoc =
1957       ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
1958   updateToLocation(ComputeLoc);
1959 
1960   ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
1961   ConstantInt *One = ConstantInt::get(IndVarTy, 1);
1962 
1963   // Like Step, but always positive.
1964   Value *Incr = Step;
1965 
1966   // Distance between Start and Stop; always positive.
1967   Value *Span;
1968 
1969   // Condition whether there are no iterations are executed at all, e.g. because
1970   // UB < LB.
1971   Value *ZeroCmp;
1972 
1973   if (IsSigned) {
1974     // Ensure that increment is positive. If not, negate and invert LB and UB.
1975     Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
1976     Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
1977     Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
1978     Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
1979     Span = Builder.CreateSub(UB, LB, "", false, true);
1980     ZeroCmp = Builder.CreateICmp(
1981         InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
1982   } else {
1983     Span = Builder.CreateSub(Stop, Start, "", true);
1984     ZeroCmp = Builder.CreateICmp(
1985         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
1986   }
1987 
1988   Value *CountIfLooping;
1989   if (InclusiveStop) {
1990     CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
1991   } else {
1992     // Avoid incrementing past stop since it could overflow.
1993     Value *CountIfTwo = Builder.CreateAdd(
1994         Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
1995     Value *OneCmp = Builder.CreateICmp(
1996         InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Span, Incr);
1997     CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
1998   }
1999   Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
2000                                           "omp_" + Name + ".tripcount");
2001 
2002   auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
2003     Builder.restoreIP(CodeGenIP);
2004     Value *Span = Builder.CreateMul(IV, Step);
2005     Value *IndVar = Builder.CreateAdd(Span, Start);
2006     BodyGenCB(Builder.saveIP(), IndVar);
2007   };
2008   LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
2009   return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
2010 }
2011 
2012 // Returns an LLVM function to call for initializing loop bounds using OpenMP
2013 // static scheduling depending on `type`. Only i32 and i64 are supported by the
2014 // runtime. Always interpret integers as unsigned similarly to
2015 // CanonicalLoopInfo.
getKmpcForStaticInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)2016 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
2017                                                   OpenMPIRBuilder &OMPBuilder) {
2018   unsigned Bitwidth = Ty->getIntegerBitWidth();
2019   if (Bitwidth == 32)
2020     return OMPBuilder.getOrCreateRuntimeFunction(
2021         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
2022   if (Bitwidth == 64)
2023     return OMPBuilder.getOrCreateRuntimeFunction(
2024         M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
2025   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2026 }
2027 
// Lower the given canonical loop into a worksharing loop with unchunked
// static scheduling: the full 0..tripcount bounds are stored into stack slots
// and handed to the __kmpc_for_static_init_* runtime call, which rewrites
// them in place; the loop is then rebased so it executes only the sub-range
// the runtime wrote back.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
                                          bool NeedsBarrier) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");

  // Set up the source location value for OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. A canonical loop
  // always iterates from 0 to trip-count with step 1. Note that "init" expects
  // and produces an inclusive upper bound.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Constant *Zero = ConstantInt::get(IVTy, 0);
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(Zero, PLowerBound);
  Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  Builder.CreateCall(StaticInit,
                     {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
                      PUpperBound, PStride, One, Zero});
  Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
  Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
  // The new trip count is the size of the range written back by the runtime;
  // the +1 accounts for the inclusive upper bound.
  Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
  Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
  CLI->setTripCount(TripCount);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.

  CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
    Builder.SetInsertPoint(CLI->getBody(),
                           CLI->getBody()->getFirstInsertionPt());
    Builder.SetCurrentDebugLocation(DL);
    // Rebase the loop-local IV by the lower bound the runtime assigned.
    return Builder.CreateAdd(OldIV, LowerBound);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(CLI->getExit(),
                         CLI->getExit()->getTerminator()->getIterator());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);

  // The loop no longer iterates over the original iteration space, so the
  // CanonicalLoopInfo must not be used for further transformations.
  InsertPointTy AfterIP = CLI->getAfterIP();
  CLI->invalidate();

  return AfterIP;
}
2114 
// Lower the given canonical loop into a worksharing loop with chunked static
// scheduling: an outer "dispatch" loop enumerates the chunks assigned to the
// current thread (starting at the runtime-provided lower bound, advancing by
// the runtime-provided stride), while the original loop is rewired to become
// the inner loop iterating over a single chunk.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    bool NeedsBarrier, Value *ChunkSize) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(ChunkSize && "Chunk size is required");

  LLVMContext &Ctx = CLI->getFunction()->getContext();
  Value *IV = CLI->getIndVar();
  Value *OrigTripCount = CLI->getTripCount();
  Type *IVTy = IV->getType();
  assert(IVTy->getIntegerBitWidth() <= 64 &&
         "Max supported tripcount bitwidth is 64 bits");
  // All runtime-facing arithmetic is done in i32 or i64 (the only widths the
  // runtime supports); results are truncated back to IVTy at the end.
  Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
                                                        : Type::getInt64Ty(Ctx);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Constant *Zero = ConstantInt::get(InternalIVTy, 0);
  Constant *One = ConstantInt::get(InternalIVTy, 1);

  // Declare useful OpenMP runtime functions.
  FunctionCallee StaticInit =
      getKmpcForStaticInitForType(InternalIVTy, M, *this);
  FunctionCallee StaticFini =
      getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Builder.SetCurrentDebugLocation(DL);
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
  Value *PUpperBound =
      Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");

  // Set up the source location value for the OpenMP runtime.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // TODO: Detect overflow in ubsan or max-out with current tripcount.
  Value *CastedChunkSize =
      Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
  Value *CastedTripCount =
      Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");

  Constant *SchedulingType = ConstantInt::get(
      I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
  Builder.CreateStore(Zero, PLowerBound);
  Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
  Builder.CreateStore(OrigUpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  // Call the "init" function and update the trip count of the loop with the
  // value it produced.
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadNum = getOrCreateThreadID(SrcLoc);
  Builder.CreateCall(StaticInit,
                     {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
                      /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
                      /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
                      /*pstride=*/PStride, /*incr=*/One,
                      /*chunk=*/CastedChunkSize});

  // Load values written by the "init" function.
  Value *FirstChunkStart =
      Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
  Value *FirstChunkStop =
      Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
  Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
  Value *ChunkRange =
      Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
  Value *NextChunkStride =
      Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");

  // Create outer "dispatch" loop for enumerating the chunks.
  BasicBlock *DispatchEnter = splitBB(Builder, true);
  // Set by the body callback of the dispatch loop created below.
  Value *DispatchCounter;
  CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
      {Builder.saveIP(), DL},
      [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
      FirstChunkStart, CastedTripCount, NextChunkStride,
      /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
      "dispatch");

  // Remember the BasicBlocks of the dispatch loop we need, then invalidate to
  // not have to preserve the canonical invariant.
  BasicBlock *DispatchBody = DispatchCLI->getBody();
  BasicBlock *DispatchLatch = DispatchCLI->getLatch();
  BasicBlock *DispatchExit = DispatchCLI->getExit();
  BasicBlock *DispatchAfter = DispatchCLI->getAfter();
  DispatchCLI->invalidate();

  // Rewire the original loop to become the chunk loop inside the dispatch loop.
  redirectTo(DispatchAfter, CLI->getAfter(), DL);
  redirectTo(CLI->getExit(), DispatchLatch, DL);
  redirectTo(DispatchBody, DispatchEnter, DL);

  // Prepare the prolog of the chunk loop.
  Builder.restoreIP(CLI->getPreheaderIP());
  Builder.SetCurrentDebugLocation(DL);

  // Compute the number of iterations of the chunk loop.
  Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
  Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
  Value *IsLastChunk =
      Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
  Value *CountUntilOrigTripCount =
      Builder.CreateSub(CastedTripCount, DispatchCounter);
  Value *ChunkTripCount = Builder.CreateSelect(
      IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
  Value *BackcastedChunkTC =
      Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
  CLI->setTripCount(BackcastedChunkTC);

  // Update all uses of the induction variable except the one in the condition
  // block that compares it with the actual upper bound, and the increment in
  // the latch block.
  Value *BackcastedDispatchCounter =
      Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
  CLI->mapIndVar([&](Instruction *) -> Value * {
    Builder.restoreIP(CLI->getBodyIP());
    // Offset the chunk-local IV by the chunk's first iteration.
    return Builder.CreateAdd(IV, BackcastedDispatchCounter);
  });

  // In the "exit" block, call the "fini" function.
  Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
  Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});

  // Add the barrier if requested.
  if (NeedsBarrier)
    createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
                  /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);

#ifndef NDEBUG
  // Even though we currently do not support applying additional methods to it,
  // the chunk loop should remain a canonical loop.
  CLI->assertOK();
#endif

  return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}
2257 
applyWorkshareLoop(DebugLoc DL,CanonicalLoopInfo * CLI,InsertPointTy AllocaIP,bool NeedsBarrier,llvm::omp::ScheduleKind SchedKind,llvm::Value * ChunkSize,bool HasSimdModifier,bool HasMonotonicModifier,bool HasNonmonotonicModifier,bool HasOrderedClause)2258 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
2259     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2260     bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
2261     llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
2262     bool HasNonmonotonicModifier, bool HasOrderedClause) {
2263   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
2264       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
2265       HasNonmonotonicModifier, HasOrderedClause);
2266 
2267   bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
2268                    OMPScheduleType::ModifierOrdered;
2269   switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
2270   case OMPScheduleType::BaseStatic:
2271     assert(!ChunkSize && "No chunk size with static-chunked schedule");
2272     if (IsOrdered)
2273       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2274                                        NeedsBarrier, ChunkSize);
2275     // FIXME: Monotonicity ignored?
2276     return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
2277 
2278   case OMPScheduleType::BaseStaticChunked:
2279     if (IsOrdered)
2280       return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2281                                        NeedsBarrier, ChunkSize);
2282     // FIXME: Monotonicity ignored?
2283     return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
2284                                            ChunkSize);
2285 
2286   case OMPScheduleType::BaseRuntime:
2287   case OMPScheduleType::BaseAuto:
2288   case OMPScheduleType::BaseGreedy:
2289   case OMPScheduleType::BaseBalanced:
2290   case OMPScheduleType::BaseSteal:
2291   case OMPScheduleType::BaseGuidedSimd:
2292   case OMPScheduleType::BaseRuntimeSimd:
2293     assert(!ChunkSize &&
2294            "schedule type does not support user-defined chunk sizes");
2295     LLVM_FALLTHROUGH;
2296   case OMPScheduleType::BaseDynamicChunked:
2297   case OMPScheduleType::BaseGuidedChunked:
2298   case OMPScheduleType::BaseGuidedIterativeChunked:
2299   case OMPScheduleType::BaseGuidedAnalyticalChunked:
2300   case OMPScheduleType::BaseStaticBalancedChunked:
2301     return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2302                                      NeedsBarrier, ChunkSize);
2303 
2304   default:
2305     llvm_unreachable("Unknown/unimplemented schedule kind");
2306   }
2307 }
2308 
2309 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
2310 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
2311 /// the runtime. Always interpret integers as unsigned similarly to
2312 /// CanonicalLoopInfo.
2313 static FunctionCallee
getKmpcForDynamicInitForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)2314 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2315   unsigned Bitwidth = Ty->getIntegerBitWidth();
2316   if (Bitwidth == 32)
2317     return OMPBuilder.getOrCreateRuntimeFunction(
2318         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
2319   if (Bitwidth == 64)
2320     return OMPBuilder.getOrCreateRuntimeFunction(
2321         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
2322   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2323 }
2324 
2325 /// Returns an LLVM function to call for updating the next loop using OpenMP
2326 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
2327 /// the runtime. Always interpret integers as unsigned similarly to
2328 /// CanonicalLoopInfo.
2329 static FunctionCallee
getKmpcForDynamicNextForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)2330 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2331   unsigned Bitwidth = Ty->getIntegerBitWidth();
2332   if (Bitwidth == 32)
2333     return OMPBuilder.getOrCreateRuntimeFunction(
2334         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
2335   if (Bitwidth == 64)
2336     return OMPBuilder.getOrCreateRuntimeFunction(
2337         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
2338   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2339 }
2340 
2341 /// Returns an LLVM function to call for finalizing the dynamic loop using
2342 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
2343 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
2344 static FunctionCallee
getKmpcForDynamicFiniForType(Type * Ty,Module & M,OpenMPIRBuilder & OMPBuilder)2345 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2346   unsigned Bitwidth = Ty->getIntegerBitWidth();
2347   if (Bitwidth == 32)
2348     return OMPBuilder.getOrCreateRuntimeFunction(
2349         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
2350   if (Bitwidth == 64)
2351     return OMPBuilder.getOrCreateRuntimeFunction(
2352         M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
2353   llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2354 }
2355 
// Lower the given canonical loop into a worksharing loop with dynamic
// scheduling: after calling __kmpc_dispatch_init_*, a new outer condition
// block calls __kmpc_dispatch_next_* to fetch the next [lower, upper]
// iteration range, and the original loop is rewired to execute one such range
// per outer iteration until "next" reports that no work is left.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
    OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
  assert(CLI->isValid() && "Requires a valid canonical loop");
  assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
         "Require dedicated allocate IP");
  assert(isValidWorkshareLoopScheduleType(SchedType) &&
         "Require valid schedule type");

  bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
                 OMPScheduleType::ModifierOrdered;

  // Set up the source location value for OpenMP runtime.
  Builder.SetCurrentDebugLocation(DL);

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
  Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);

  // Declare useful OpenMP runtime functions.
  Value *IV = CLI->getIndVar();
  Type *IVTy = IV->getType();
  FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
  FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);

  // Allocate space for computed loop bounds as expected by the "init" function.
  Builder.restoreIP(AllocaIP);
  Type *I32Type = Type::getInt32Ty(M.getContext());
  Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
  Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
  Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
  Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");

  // At the end of the preheader, prepare for calling the "init" function by
  // storing the current loop bounds into the allocated space. The dynamic
  // dispatch runtime is handed the 1-based range [1, tripcount] with an
  // inclusive upper bound; the lower bound returned by "next" is rebased to
  // the canonical 0-based IV below (see the "lb" subtraction).
  BasicBlock *PreHeader = CLI->getPreheader();
  Builder.SetInsertPoint(PreHeader->getTerminator());
  Constant *One = ConstantInt::get(IVTy, 1);
  Builder.CreateStore(One, PLowerBound);
  Value *UpperBound = CLI->getTripCount();
  Builder.CreateStore(UpperBound, PUpperBound);
  Builder.CreateStore(One, PStride);

  BasicBlock *Header = CLI->getHeader();
  BasicBlock *Exit = CLI->getExit();
  BasicBlock *Cond = CLI->getCond();
  BasicBlock *Latch = CLI->getLatch();
  InsertPointTy AfterIP = CLI->getAfterIP();

  // The CLI will be "broken" in the code below, as the loop is no longer
  // a valid canonical loop.

  if (!Chunk)
    Chunk = One;

  Value *ThreadNum = getOrCreateThreadID(SrcLoc);

  Constant *SchedulingType =
      ConstantInt::get(I32Type, static_cast<int>(SchedType));

  // Call the "init" function.
  Builder.CreateCall(DynamicInit,
                     {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
                      UpperBound, /* step */ One, Chunk});

  // An outer loop around the existing one.
  BasicBlock *OuterCond = BasicBlock::Create(
      PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
      PreHeader->getParent());
  Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
  Value *Res =
      Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
                                       PLowerBound, PUpperBound, PStride});
  // The result of the "next" call is a 32-bit "more work" flag, so it must be
  // compared against an i32 zero rather than an IVTy constant.
  Constant *Zero32 = ConstantInt::get(I32Type, 0);
  Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
  Value *LowerBound =
      Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
  Builder.CreateCondBr(MoreWork, Header, Exit);

  // Change PHI-node in loop header to use outer cond rather than preheader,
  // and set IV to the LowerBound.
  Instruction *Phi = &Header->front();
  auto *PI = cast<PHINode>(Phi);
  PI->setIncomingBlock(0, OuterCond);
  PI->setIncomingValue(0, LowerBound);

  // Then set the pre-header to jump to the OuterCond
  Instruction *Term = PreHeader->getTerminator();
  auto *Br = cast<BranchInst>(Term);
  Br->setSuccessor(0, OuterCond);

  // Modify the inner condition:
  // * Use the UpperBound returned from the DynamicNext call.
  // * jump to the loop outer loop when done with one of the inner loops.
  Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
  UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
  Instruction *Comp = &*Builder.GetInsertPoint();
  auto *CI = cast<CmpInst>(Comp);
  CI->setOperand(1, UpperBound);
  // Redirect the inner exit to branch to outer condition.
  Instruction *Branch = &Cond->back();
  auto *BI = cast<BranchInst>(Branch);
  assert(BI->getSuccessor(1) == Exit);
  BI->setSuccessor(1, OuterCond);

  // Call the "fini" function if "ordered" is present in wsloop directive.
  if (Ordered) {
    Builder.SetInsertPoint(&Latch->back());
    FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
    Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
  }

  // Add the barrier if requested.
  if (NeedsBarrier) {
    Builder.SetInsertPoint(&Exit->back());
    createBarrier(LocationDescription(Builder.saveIP(), DL),
                  omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
                  /* CheckCancelFlag */ false);
  }

  CLI->invalidate();
  return AfterIP;
}
2482 
2483 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
2484 /// after this \p OldTarget will be orphaned.
redirectAllPredecessorsTo(BasicBlock * OldTarget,BasicBlock * NewTarget,DebugLoc DL)2485 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
2486                                       BasicBlock *NewTarget, DebugLoc DL) {
2487   for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
2488     redirectTo(Pred, NewTarget, DL);
2489 }
2490 
2491 /// Determine which blocks in \p BBs are reachable from outside and remove the
2492 /// ones that are not reachable from the function.
removeUnusedBlocksFromParent(ArrayRef<BasicBlock * > BBs)2493 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
2494   SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
2495   auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
2496     for (Use &U : BB->uses()) {
2497       auto *UseInst = dyn_cast<Instruction>(U.getUser());
2498       if (!UseInst)
2499         continue;
2500       if (BBsToErase.count(UseInst->getParent()))
2501         continue;
2502       return true;
2503     }
2504     return false;
2505   };
2506 
2507   while (true) {
2508     bool Changed = false;
2509     for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
2510       if (HasRemainingUses(BB)) {
2511         BBsToErase.erase(BB);
2512         Changed = true;
2513       }
2514     }
2515     if (!Changed)
2516       break;
2517   }
2518 
2519   SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
2520   DeleteDeadBlocks(BBVec);
2521 }
2522 
// Implementation overview: the input loop nest is replaced by one canonical
// loop whose trip count is the product of the original trip counts. Each
// original induction variable is re-derived from the collapsed one via a
// div/mod chain (innermost loop in the least significant "digits"), and the
// original loops' control blocks are discarded.
CanonicalLoopInfo *
OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                               InsertPointTy ComputeIP) {
  assert(Loops.size() >= 1 && "At least one loop required");
  size_t NumLoops = Loops.size();

  // Nothing to do if there is already just one loop.
  if (NumLoops == 1)
    return Loops.front();

  CanonicalLoopInfo *Outermost = Loops.front();
  CanonicalLoopInfo *Innermost = Loops.back();
  BasicBlock *OrigPreheader = Outermost->getPreheader();
  BasicBlock *OrigAfter = Outermost->getAfter();
  Function *F = OrigPreheader->getParent();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Setup the IRBuilder for inserting the trip count computation.
  Builder.SetCurrentDebugLocation(DL);
  if (ComputeIP.isSet())
    Builder.restoreIP(ComputeIP);
  else
    Builder.restoreIP(Outermost->getPreheaderIP());

  // Derive the collapsed loop's trip count.
  // TODO: Find common/largest indvar type.
  Value *CollapsedTripCount = nullptr;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() &&
           "All loops to collapse must be valid canonical loops");
    Value *OrigTripCount = L->getTripCount();
    if (!CollapsedTripCount) {
      CollapsedTripCount = OrigTripCount;
      continue;
    }

    // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
    CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
                                           {}, /*HasNUW=*/true);
  }

  // Create the collapsed loop control flow.
  CanonicalLoopInfo *Result =
      createLoopSkeleton(DL, CollapsedTripCount, F,
                         OrigPreheader->getNextNode(), OrigAfter, "collapsed");

  // Build the collapsed loop body code.
  // Start with deriving the input loop induction variables from the collapsed
  // one, using a divmod scheme. To preserve the original loops' order, the
  // innermost loop use the least significant bits.
  Builder.restoreIP(Result->getBodyIP());

  Value *Leftover = Result->getIndVar();
  SmallVector<Value *> NewIndVars;
  NewIndVars.resize(NumLoops);
  for (int i = NumLoops - 1; i >= 1; --i) {
    Value *OrigTripCount = Loops[i]->getTripCount();

    Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
    NewIndVars[i] = NewIndVar;

    Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
  }
  // Outermost loop gets all the remaining bits.
  NewIndVars[0] = Leftover;

  // Construct the loop body control flow.
  // We progressively construct the branch structure following in direction of
  // the control flow, from the leading in-between code, the loop nest body, the
  // trailing in-between code, and rejoining the collapsed loop's latch.
  // ContinueBlock and ContinuePred keep track of the source(s) of next edge. If
  // the ContinueBlock is set, continue with that block. If ContinuePred, use
  // its predecessors as sources.
  BasicBlock *ContinueBlock = Result->getBody();
  BasicBlock *ContinuePred = nullptr;
  auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
                                                          BasicBlock *NextSrc) {
    if (ContinueBlock)
      redirectTo(ContinueBlock, Dest, DL);
    else
      redirectAllPredecessorsTo(ContinuePred, Dest, DL);

    ContinueBlock = nullptr;
    ContinuePred = NextSrc;
  };

  // The code before the nested loop of each level.
  // Because we are sinking it into the nest, it will be executed more often
  // than the original loop. More sophisticated schemes could keep track of
  // what the in-between code is and instantiate it only once per thread.
  for (size_t i = 0; i < NumLoops - 1; ++i)
    ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());

  // Connect the loop nest body.
  ContinueWith(Innermost->getBody(), Innermost->getLatch());

  // The code after the nested loop at each level.
  for (size_t i = NumLoops - 1; i > 0; --i)
    ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());

  // Connect the finished loop to the collapsed loop latch.
  ContinueWith(Result->getLatch(), nullptr);

  // Replace the input loops with the new collapsed loop.
  redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
  redirectTo(Result->getAfter(), Outermost->getAfter(), DL);

  // Replace the input loop indvars with the derived ones.
  for (size_t i = 0; i < NumLoops; ++i)
    Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);

  // Remove unused parts of the input loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  Result->assertOK();
#endif
  return Result;
}
2650 
2651 std::vector<CanonicalLoopInfo *>
OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
                           ArrayRef<Value *> TileSizes) {
  // For each of the N input loops this creates a "floor" loop iterating over
  // the tiles and a "tile" loop iterating within one tile; the result is a
  // nest of 2*N loops: all floor loops first (outermost first), then all tile
  // loops. One tile size per input loop is required.
  assert(TileSizes.size() == Loops.size() &&
         "Must pass as many tile sizes as there are loops");
  int NumLoops = Loops.size();
  assert(NumLoops >= 1 && "At least one loop to tile required");

  CanonicalLoopInfo *OutermostLoop = Loops.front();
  CanonicalLoopInfo *InnermostLoop = Loops.back();
  Function *F = OutermostLoop->getBody()->getParent();
  BasicBlock *InnerEnter = InnermostLoop->getBody();
  BasicBlock *InnerLatch = InnermostLoop->getLatch();

  // Loop control blocks that may become orphaned later.
  SmallVector<BasicBlock *, 12> OldControlBBs;
  OldControlBBs.reserve(6 * Loops.size());
  for (CanonicalLoopInfo *Loop : Loops)
    Loop->collectControlBlocks(OldControlBBs);

  // Collect original trip counts and induction variable to be accessible by
  // index. Also, the structure of the original loops is not preserved during
  // the construction of the tiled loops, so do it before we scavenge the BBs of
  // any original CanonicalLoopInfo.
  SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
  for (CanonicalLoopInfo *L : Loops) {
    assert(L->isValid() && "All input loops must be valid canonical loops");
    OrigTripCounts.push_back(L->getTripCount());
    OrigIndVars.push_back(L->getIndVar());
  }

  // Collect the code between loop headers. These may contain SSA definitions
  // that are used in the loop nest body. To be usable with in the innermost
  // body, these BasicBlocks will be sunk into the loop nest body. That is,
  // these instructions may be executed more often than before the tiling.
  // TODO: It would be sufficient to only sink them into body of the
  // corresponding tile loop.
  SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
  for (int i = 0; i < NumLoops - 1; ++i) {
    CanonicalLoopInfo *Surrounding = Loops[i];
    CanonicalLoopInfo *Nested = Loops[i + 1];

    BasicBlock *EnterBB = Surrounding->getBody();
    BasicBlock *ExitBB = Nested->getHeader();
    InbetweenCode.emplace_back(EnterBB, ExitBB);
  }

  // Compute the trip counts of the floor loops.
  Builder.SetCurrentDebugLocation(DL);
  Builder.restoreIP(OutermostLoop->getPreheaderIP());
  SmallVector<Value *, 4> FloorCount, FloorRems;
  for (int i = 0; i < NumLoops; ++i) {
    Value *TileSize = TileSizes[i];
    Value *OrigTripCount = OrigTripCounts[i];
    Type *IVType = OrigTripCount->getType();

    Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
    Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);

    // 0 if tripcount divides the tilesize, 1 otherwise.
    // 1 means we need an additional iteration for a partial tile.
    //
    // Unfortunately we cannot just use the roundup-formula
    //   (tripcount + tilesize - 1)/tilesize
    // because the summation might overflow. We do not want introduce undefined
    // behavior when the untiled loop nest did not.
    Value *FloorTripOverflow =
        Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));

    FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
    FloorTripCount =
        Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
                          "omp_floor" + Twine(i) + ".tripcount", true);

    // Remember some values for later use. Note that FloorCount stores the
    // rounded-up trip count (udiv result plus the overflow bit).
    FloorCount.push_back(FloorTripCount);
    FloorRems.push_back(FloorTripRem);
  }

  // Generate the new loop nest, from the outermost to the innermost.
  std::vector<CanonicalLoopInfo *> Result;
  Result.reserve(NumLoops * 2);

  // The basic block of the surrounding loop that enters the nest generated
  // loop.
  BasicBlock *Enter = OutermostLoop->getPreheader();

  // The basic block of the surrounding loop where the inner code should
  // continue.
  BasicBlock *Continue = OutermostLoop->getAfter();

  // Where the next loop basic block should be inserted.
  BasicBlock *OutroInsertBefore = InnermostLoop->getExit();

  // Create one loop skeleton with the given trip count and splice it between
  // Enter and Continue; then update Enter/Continue/OutroInsertBefore so the
  // next invocation nests inside the loop just created.
  auto EmbeddNewLoop =
      [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
          Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
    CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
        DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
    redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
    redirectTo(EmbeddedLoop->getAfter(), Continue, DL);

    // Setup the position where the next embedded loop connects to this loop.
    Enter = EmbeddedLoop->getBody();
    Continue = EmbeddedLoop->getLatch();
    OutroInsertBefore = EmbeddedLoop->getLatch();
    return EmbeddedLoop;
  };

  // Embed one loop per trip count, appending each new CanonicalLoopInfo to
  // Result.
  auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
                                                  const Twine &NameBase) {
    for (auto P : enumerate(TripCounts)) {
      CanonicalLoopInfo *EmbeddedLoop =
          EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
      Result.push_back(EmbeddedLoop);
    }
  };

  EmbeddNewLoops(FloorCount, "floor");

  // Within the innermost floor loop, emit the code that computes the tile
  // sizes.
  Builder.SetInsertPoint(Enter->getTerminator());
  SmallVector<Value *, 4> TileCounts;
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    Value *TileSize = TileSizes[i];

    // NOTE(review): the floor IV ranges over [0, FloorCount[i]), so an
    // equality test against FloorCount[i] itself looks unsatisfiable, which
    // would make the select below always pick TileSize. Verify the epilogue
    // (partial-tile) handling for trip counts not divisible by the tile size.
    Value *FloorIsEpilogue =
        Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
    Value *TileTripCount =
        Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);

    TileCounts.push_back(TileTripCount);
  }

  // Create the tile loops.
  EmbeddNewLoops(TileCounts, "tile");

  // Insert the inbetween code into the body.
  BasicBlock *BodyEnter = Enter;
  BasicBlock *BodyEntered = nullptr;
  for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
    BasicBlock *EnterBB = P.first;
    BasicBlock *ExitBB = P.second;

    // The first chunk is branched to directly from the nest body entry; later
    // chunks are appended after the previously sunk chunk.
    if (BodyEnter)
      redirectTo(BodyEnter, EnterBB, DL);
    else
      redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);

    BodyEnter = nullptr;
    BodyEntered = ExitBB;
  }

  // Append the original loop nest body into the generated loop nest body.
  if (BodyEnter)
    redirectTo(BodyEnter, InnerEnter, DL);
  else
    redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
  redirectAllPredecessorsTo(InnerLatch, Continue, DL);

  // Replace the original induction variable with an induction variable computed
  // from the tile and floor induction variables.
  Builder.restoreIP(Result.back()->getBodyIP());
  for (int i = 0; i < NumLoops; ++i) {
    CanonicalLoopInfo *FloorLoop = Result[i];
    CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
    Value *OrigIndVar = OrigIndVars[i];
    Value *Size = TileSizes[i];

    // Original IV = floor IV * tile size + tile IV, both operations emitted
    // with the nuw flag.
    Value *Scale =
        Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
    Value *Shift =
        Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
    OrigIndVar->replaceAllUsesWith(Shift);
  }

  // Remove unused parts of the original loops.
  removeUnusedBlocksFromParent(OldControlBBs);

  // The original loops' CFG has been scavenged; their CanonicalLoopInfos must
  // not be used anymore.
  for (CanonicalLoopInfo *L : Loops)
    L->invalidate();

#ifndef NDEBUG
  for (CanonicalLoopInfo *GenL : Result)
    GenL->assertOK();
#endif
  return Result;
}
2841 
2842 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
2843 /// loop already has metadata, the loop properties are appended.
addLoopMetadata(CanonicalLoopInfo * Loop,ArrayRef<Metadata * > Properties)2844 static void addLoopMetadata(CanonicalLoopInfo *Loop,
2845                             ArrayRef<Metadata *> Properties) {
2846   assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
2847 
2848   // Nothing to do if no property to attach.
2849   if (Properties.empty())
2850     return;
2851 
2852   LLVMContext &Ctx = Loop->getFunction()->getContext();
2853   SmallVector<Metadata *> NewLoopProperties;
2854   NewLoopProperties.push_back(nullptr);
2855 
2856   // If the loop already has metadata, prepend it to the new metadata.
2857   BasicBlock *Latch = Loop->getLatch();
2858   assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
2859   MDNode *Existing = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop);
2860   if (Existing)
2861     append_range(NewLoopProperties, drop_begin(Existing->operands(), 1));
2862 
2863   append_range(NewLoopProperties, Properties);
2864   MDNode *LoopID = MDNode::getDistinct(Ctx, NewLoopProperties);
2865   LoopID->replaceOperandWith(0, LoopID);
2866 
2867   Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
2868 }
2869 
2870 /// Attach llvm.access.group metadata to the memref instructions of \p Block
addSimdMetadata(BasicBlock * Block,MDNode * AccessGroup,LoopInfo & LI)2871 static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
2872                             LoopInfo &LI) {
2873   for (Instruction &I : *Block) {
2874     if (I.mayReadOrWriteMemory()) {
2875       // TODO: This instruction may already have access group from
2876       // other pragmas e.g. #pragma clang loop vectorize.  Append
2877       // so that the existing metadata is not overwritten.
2878       I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
2879     }
2880   }
2881 }
2882 
unrollLoopFull(DebugLoc,CanonicalLoopInfo * Loop)2883 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
2884   LLVMContext &Ctx = Builder.getContext();
2885   addLoopMetadata(
2886       Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
2887              MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
2888 }
2889 
unrollLoopHeuristic(DebugLoc,CanonicalLoopInfo * Loop)2890 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
2891   LLVMContext &Ctx = Builder.getContext();
2892   addLoopMetadata(
2893       Loop, {
2894                 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
2895             });
2896 }
2897 
applySimd(CanonicalLoopInfo * CanonicalLoop,ConstantInt * Simdlen)2898 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
2899                                 ConstantInt *Simdlen) {
2900   LLVMContext &Ctx = Builder.getContext();
2901 
2902   Function *F = CanonicalLoop->getFunction();
2903 
2904   FunctionAnalysisManager FAM;
2905   FAM.registerPass([]() { return DominatorTreeAnalysis(); });
2906   FAM.registerPass([]() { return LoopAnalysis(); });
2907   FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
2908 
2909   LoopAnalysis LIA;
2910   LoopInfo &&LI = LIA.run(*F, FAM);
2911 
2912   Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
2913 
2914   SmallSet<BasicBlock *, 8> Reachable;
2915 
2916   // Get the basic blocks from the loop in which memref instructions
2917   // can be found.
2918   // TODO: Generalize getting all blocks inside a CanonicalizeLoopInfo,
2919   // preferably without running any passes.
2920   for (BasicBlock *Block : L->getBlocks()) {
2921     if (Block == CanonicalLoop->getCond() ||
2922         Block == CanonicalLoop->getHeader())
2923       continue;
2924     Reachable.insert(Block);
2925   }
2926 
2927   // Add access group metadata to memory-access instructions.
2928   MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
2929   for (BasicBlock *BB : Reachable)
2930     addSimdMetadata(BB, AccessGroup, LI);
2931 
2932   // Use the above access group metadata to create loop level
2933   // metadata, which should be distinct for each loop.
2934   ConstantAsMetadata *BoolConst =
2935       ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
2936   // TODO:  If the loop has existing parallel access metadata, have
2937   // to combine two lists.
2938   addLoopMetadata(
2939       CanonicalLoop,
2940       {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"),
2941                          AccessGroup}),
2942        MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
2943                          BoolConst})});
2944   if (Simdlen != nullptr)
2945     addLoopMetadata(
2946         CanonicalLoop,
2947         MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
2948                           ConstantAsMetadata::get(Simdlen)}));
2949 }
2950 
2951 /// Create the TargetMachine object to query the backend for optimization
2952 /// preferences.
2953 ///
2954 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
2955 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
2956 /// needed for the LLVM pass pipline. We use some default options to avoid
2957 /// having to pass too many settings from the frontend that probably do not
2958 /// matter.
2959 ///
2960 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
2961 /// method. If we are going to use TargetMachine for more purposes, especially
2962 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
2963 /// might become be worth requiring front-ends to pass on their TargetMachine,
2964 /// or at least cache it between methods. Note that while fontends such as Clang
2965 /// have just a single main TargetMachine per translation unit, "target-cpu" and
2966 /// "target-features" that determine the TargetMachine are per-function and can
2967 /// be overrided using __attribute__((target("OPTIONS"))).
2968 static std::unique_ptr<TargetMachine>
createTargetMachine(Function * F,CodeGenOpt::Level OptLevel)2969 createTargetMachine(Function *F, CodeGenOpt::Level OptLevel) {
2970   Module *M = F->getParent();
2971 
2972   StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
2973   StringRef Features = F->getFnAttribute("target-features").getValueAsString();
2974   const std::string &Triple = M->getTargetTriple();
2975 
2976   std::string Error;
2977   const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
2978   if (!TheTarget)
2979     return {};
2980 
2981   llvm::TargetOptions Options;
2982   return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
2983       Triple, CPU, Features, Options, /*RelocModel=*/None, /*CodeModel=*/None,
2984       OptLevel));
2985 }
2986 
2987 /// Heuristically determine the best-performant unroll factor for \p CLI. This
2988 /// depends on the target processor. We are re-using the same heuristics as the
2989 /// LoopUnrollPass.
static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
  Function *F = CLI->getFunction();

  // Assume the user requests the most aggressive unrolling, even if the rest of
  // the code is optimized using a lower setting.
  CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
  std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);

  // Register exactly the analyses the unroll heuristics below consult; no
  // other part of the pass pipeline runs here.
  FunctionAnalysisManager FAM;
  FAM.registerPass([]() { return TargetLibraryAnalysis(); });
  FAM.registerPass([]() { return AssumptionAnalysis(); });
  FAM.registerPass([]() { return DominatorTreeAnalysis(); });
  FAM.registerPass([]() { return LoopAnalysis(); });
  FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
  FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
  TargetIRAnalysis TIRA;
  // Without a TargetMachine (e.g. unregistered triple) the default-constructed
  // TargetIRAnalysis is used instead.
  if (TM)
    TIRA = TargetIRAnalysis(
        [&](const Function &F) { return TM->getTargetTransformInfo(F); });
  FAM.registerPass([&]() { return TIRA; });

  // Run each analysis directly and keep the results alive via rvalue refs.
  TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
  ScalarEvolutionAnalysis SEA;
  ScalarEvolution &&SE = SEA.run(*F, FAM);
  DominatorTreeAnalysis DTA;
  DominatorTree &&DT = DTA.run(*F, FAM);
  LoopAnalysis LIA;
  LoopInfo &&LI = LIA.run(*F, FAM);
  AssumptionAnalysis ACT;
  AssumptionCache &&AC = ACT.run(*F, FAM);
  OptimizationRemarkEmitter ORE{F};

  Loop *L = LI.getLoopFor(CLI->getHeader());
  assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");

  // Query the same preferences the LoopUnrollPass would use, allowing partial
  // and runtime unrolling.
  TargetTransformInfo::UnrollingPreferences UP =
      gatherUnrollingPreferences(L, SE, TTI,
                                 /*BlockFrequencyInfo=*/nullptr,
                                 /*ProfileSummaryInfo=*/nullptr, ORE, OptLevel,
                                 /*UserThreshold=*/None,
                                 /*UserCount=*/None,
                                 /*UserAllowPartial=*/true,
                                 /*UserAllowRuntime=*/true,
                                 /*UserUpperBound=*/None,
                                 /*UserFullUnrollMaxCount=*/None);

  UP.Force = true;

  // Account for additional optimizations taking place before the LoopUnrollPass
  // would unroll the loop.
  UP.Threshold *= UnrollThresholdFactor;
  UP.PartialThreshold *= UnrollThresholdFactor;

  // Use normal unroll factors even if the rest of the code is optimized for
  // size.
  UP.OptSizeThreshold = UP.Threshold;
  UP.PartialOptSizeThreshold = UP.PartialThreshold;

  LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
                    << "  Threshold=" << UP.Threshold << "\n"
                    << "  PartialThreshold=" << UP.PartialThreshold << "\n"
                    << "  OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
                    << "  PartialOptSizeThreshold="
                    << UP.PartialOptSizeThreshold << "\n");

  // Disable peeling.
  TargetTransformInfo::PeelingPreferences PP =
      gatherPeelingPreferences(L, SE, TTI,
                               /*UserAllowPeeling=*/false,
                               /*UserAllowProfileBasedPeeling=*/false,
                               /*UnrollingSpecficValues=*/false);

  SmallPtrSet<const Value *, 32> EphValues;
  CodeMetrics::collectEphemeralValues(L, &AC, EphValues);

  // Assume that reads and writes to stack variables can be eliminated by
  // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
  // size.
  for (BasicBlock *BB : L->blocks()) {
    for (Instruction &I : *BB) {
      Value *Ptr;
      if (auto *Load = dyn_cast<LoadInst>(&I)) {
        Ptr = Load->getPointerOperand();
      } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
        Ptr = Store->getPointerOperand();
      } else
        continue;

      Ptr = Ptr->stripPointerCasts();

      // Only accesses through entry-block allocas (function-local stack
      // slots) are treated as free.
      if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
        if (Alloca->getParent() == &F->getEntryBlock())
          EphValues.insert(&I);
      }
    }
  }

  unsigned NumInlineCandidates;
  bool NotDuplicatable;
  bool Convergent;
  InstructionCost LoopSizeIC =
      ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
                          TTI, EphValues, UP.BEInsns);
  LLVM_DEBUG(dbgs() << "Estimated loop size is " << LoopSizeIC << "\n");

  // Loop is not unrollable if the loop contains certain instructions.
  if (NotDuplicatable || Convergent || !LoopSizeIC.isValid()) {
    LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
    return 1;
  }
  unsigned LoopSize = *LoopSizeIC.getValue();

  // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
  // be able to use it.
  int TripCount = 0;
  int MaxTripCount = 0;
  bool MaxOrZero = false;
  unsigned TripMultiple = 0;

  // Let the LoopUnrollPass heuristic fill in UP.Count.
  bool UseUpperBound = false;
  computeUnrollCount(L, TTI, DT, &LI, SE, EphValues, &ORE, TripCount,
                     MaxTripCount, MaxOrZero, TripMultiple, LoopSize, UP, PP,
                     UseUpperBound);
  unsigned Factor = UP.Count;
  LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");

  // This function returns 1 to signal to not unroll a loop.
  if (Factor == 0)
    return 1;
  return Factor;
}
3121 
void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
                                        int32_t Factor,
                                        CanonicalLoopInfo **UnrolledCLI) {
  // Factor == 0 requests a heuristically chosen factor; Factor >= 1 requests
  // that exact factor. UnrolledCLI, if non-null, receives a CanonicalLoopInfo
  // for the loop that remains after unrolling (for further directives).
  assert(Factor >= 0 && "Unroll factor must not be negative");

  Function *F = Loop->getFunction();
  LLVMContext &Ctx = F->getContext();

  // If the unrolled loop is not used for another loop-associated directive, it
  // is sufficient to add metadata for the LoopUnrollPass.
  if (!UnrolledCLI) {
    SmallVector<Metadata *, 2> LoopMetadata;
    LoopMetadata.push_back(
        MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));

    // With an explicit factor, also pin the unroll count; with Factor == 0
    // only the enable property is emitted and the pass picks the count.
    if (Factor >= 1) {
      ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
          ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
      LoopMetadata.push_back(MDNode::get(
          Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
    }

    addLoopMetadata(Loop, LoopMetadata);
    return;
  }

  // Heuristically determine the unroll factor.
  if (Factor == 0)
    Factor = computeHeuristicUnrollFactor(Loop);

  // No change required with unroll factor 1.
  if (Factor == 1) {
    *UnrolledCLI = Loop;
    return;
  }

  assert(Factor >= 2 &&
         "unrolling only makes sense with a factor of 2 or larger");

  Type *IndVarTy = Loop->getIndVarType();

  // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
  // unroll the inner loop.
  Value *FactorVal =
      ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
                                       /*isSigned=*/false));
  std::vector<CanonicalLoopInfo *> LoopNest =
      tileLoops(DL, {Loop}, {FactorVal});
  assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
  // The outer (floor) loop is what the caller continues to work with.
  *UnrolledCLI = LoopNest[0];
  CanonicalLoopInfo *InnerLoop = LoopNest[1];

  // LoopUnrollPass can only fully unroll loops with constant trip count.
  // Unroll by the unroll factor with a fallback epilog for the remainder
  // iterations if necessary.
  ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
      ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
  addLoopMetadata(
      InnerLoop,
      {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
       MDNode::get(
           Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});

#ifndef NDEBUG
  (*UnrolledCLI)->assertOK();
#endif
}
3189 
3190 OpenMPIRBuilder::InsertPointTy
createCopyPrivate(const LocationDescription & Loc,llvm::Value * BufSize,llvm::Value * CpyBuf,llvm::Value * CpyFn,llvm::Value * DidIt)3191 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
3192                                    llvm::Value *BufSize, llvm::Value *CpyBuf,
3193                                    llvm::Value *CpyFn, llvm::Value *DidIt) {
3194   if (!updateToLocation(Loc))
3195     return Loc.IP;
3196 
3197   uint32_t SrcLocStrSize;
3198   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3199   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3200   Value *ThreadId = getOrCreateThreadID(Ident);
3201 
3202   llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
3203 
3204   Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
3205 
3206   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
3207   Builder.CreateCall(Fn, Args);
3208 
3209   return Builder.saveIP();
3210 }
3211 
createSingle(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,bool IsNowait,llvm::Value * DidIt)3212 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
3213     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3214     FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
3215 
3216   if (!updateToLocation(Loc))
3217     return Loc.IP;
3218 
3219   // If needed (i.e. not null), initialize `DidIt` with 0
3220   if (DidIt) {
3221     Builder.CreateStore(Builder.getInt32(0), DidIt);
3222   }
3223 
3224   Directive OMPD = Directive::OMPD_single;
3225   uint32_t SrcLocStrSize;
3226   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3227   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3228   Value *ThreadId = getOrCreateThreadID(Ident);
3229   Value *Args[] = {Ident, ThreadId};
3230 
3231   Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
3232   Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3233 
3234   Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
3235   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3236 
3237   // generates the following:
3238   // if (__kmpc_single()) {
3239   //		.... single region ...
3240   // 		__kmpc_end_single
3241   // }
3242   // __kmpc_barrier
3243 
3244   EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3245                        /*Conditional*/ true,
3246                        /*hasFinalize*/ true);
3247   if (!IsNowait)
3248     createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
3249                   omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
3250                   /* CheckCancelFlag */ false);
3251   return Builder.saveIP();
3252 }
3253 
createCritical(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,StringRef CriticalName,Value * HintInst)3254 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
3255     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3256     FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
3257 
3258   if (!updateToLocation(Loc))
3259     return Loc.IP;
3260 
3261   Directive OMPD = Directive::OMPD_critical;
3262   uint32_t SrcLocStrSize;
3263   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3264   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3265   Value *ThreadId = getOrCreateThreadID(Ident);
3266   Value *LockVar = getOMPCriticalRegionLock(CriticalName);
3267   Value *Args[] = {Ident, ThreadId, LockVar};
3268 
3269   SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
3270   Function *RTFn = nullptr;
3271   if (HintInst) {
3272     // Add Hint to entry Args and create call
3273     EnterArgs.push_back(HintInst);
3274     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
3275   } else {
3276     RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
3277   }
3278   Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
3279 
3280   Function *ExitRTLFn =
3281       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
3282   Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3283 
3284   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3285                               /*Conditional*/ false, /*hasFinalize*/ true);
3286 }
3287 
3288 OpenMPIRBuilder::InsertPointTy
createOrderedDepend(const LocationDescription & Loc,InsertPointTy AllocaIP,unsigned NumLoops,ArrayRef<llvm::Value * > StoreValues,const Twine & Name,bool IsDependSource)3289 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
3290                                      InsertPointTy AllocaIP, unsigned NumLoops,
3291                                      ArrayRef<llvm::Value *> StoreValues,
3292                                      const Twine &Name, bool IsDependSource) {
3293   for (size_t I = 0; I < StoreValues.size(); I++)
3294     assert(StoreValues[I]->getType()->isIntegerTy(64) &&
3295            "OpenMP runtime requires depend vec with i64 type");
3296 
3297   if (!updateToLocation(Loc))
3298     return Loc.IP;
3299 
3300   // Allocate space for vector and generate alloc instruction.
3301   auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
3302   Builder.restoreIP(AllocaIP);
3303   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
3304   ArgsBase->setAlignment(Align(8));
3305   Builder.restoreIP(Loc.IP);
3306 
3307   // Store the index value with offset in depend vector.
3308   for (unsigned I = 0; I < NumLoops; ++I) {
3309     Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
3310         ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
3311     StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
3312     STInst->setAlignment(Align(8));
3313   }
3314 
3315   Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
3316       ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
3317 
3318   uint32_t SrcLocStrSize;
3319   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3320   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3321   Value *ThreadId = getOrCreateThreadID(Ident);
3322   Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
3323 
3324   Function *RTLFn = nullptr;
3325   if (IsDependSource)
3326     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
3327   else
3328     RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
3329   Builder.CreateCall(RTLFn, Args);
3330 
3331   return Builder.saveIP();
3332 }
3333 
createOrderedThreadsSimd(const LocationDescription & Loc,BodyGenCallbackTy BodyGenCB,FinalizeCallbackTy FiniCB,bool IsThreads)3334 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
3335     const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3336     FinalizeCallbackTy FiniCB, bool IsThreads) {
3337   if (!updateToLocation(Loc))
3338     return Loc.IP;
3339 
3340   Directive OMPD = Directive::OMPD_ordered;
3341   Instruction *EntryCall = nullptr;
3342   Instruction *ExitCall = nullptr;
3343 
3344   if (IsThreads) {
3345     uint32_t SrcLocStrSize;
3346     Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3347     Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3348     Value *ThreadId = getOrCreateThreadID(Ident);
3349     Value *Args[] = {Ident, ThreadId};
3350 
3351     Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
3352     EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3353 
3354     Function *ExitRTLFn =
3355         getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
3356     ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3357   }
3358 
3359   return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3360                               /*Conditional*/ false, /*hasFinalize*/ true);
3361 }
3362 
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
    Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
    BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
    bool HasFinalize, bool IsCancellable) {

  // Register the finalization callback so constructs emitted inside the body
  // (e.g. cancellation points) can run it before leaving this region.
  if (HasFinalize)
    FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});

  // Create inlined region's entry and body blocks, in preparation
  // for conditional creation
  BasicBlock *EntryBB = Builder.GetInsertBlock();
  Instruction *SplitPos = EntryBB->getTerminator();
  // If the block is not ended by a branch, plant a temporary unreachable as a
  // placeholder terminator to split at; it is erased again at the end.
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
  BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
  BasicBlock *FiniBB =
      EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");

  // Emit the entry call and, if Conditional, the branch that can skip the
  // whole region.
  Builder.SetInsertPoint(EntryBB->getTerminator());
  emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);

  // generate body
  BodyGenCB(/* AllocaIP */ InsertPointTy(),
            /* CodeGenIP */ Builder.saveIP());

  // emit exit call and do any needed finalization.
  auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
  assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
         FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
         "Unexpected control flow graph state!!");
  emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
  assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
         "Unexpected Control Flow State!");
  // The finalization block has a single predecessor again; fold it back in.
  MergeBlockIntoPredecessor(FiniBB);

  // If we are skipping the region of a non conditional, remove the exit
  // block, and clear the builder's insertion point.
  assert(SplitPos->getParent() == ExitBB &&
         "Unexpected Insertion point location!");
  auto merged = MergeBlockIntoPredecessor(ExitBB);
  BasicBlock *ExitPredBB = SplitPos->getParent();
  auto InsertBB = merged ? ExitPredBB : ExitBB;
  // Remove the placeholder terminator created above, if any.
  if (!isa_and_nonnull<BranchInst>(SplitPos))
    SplitPos->eraseFromParent();
  Builder.SetInsertPoint(InsertBB);

  return Builder.saveIP();
}
3411 
/// Emit the entry guard of a directive region: when \p Conditional is set and
/// an \p EntryCall value is given, branch into the region body only if the
/// entry call's result is non-zero, otherwise jump to \p ExitBB.
/// Note: \p OMPD is currently unused in this function.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
    Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
  // if nothing to do, Return current insertion point.
  if (!Conditional || !EntryCall)
    return Builder.saveIP();

  BasicBlock *EntryBB = Builder.GetInsertBlock();
  // Enter the body only when the entry call returned a non-zero value.
  Value *CallBool = Builder.CreateIsNotNull(EntryCall);
  auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
  // Temporary terminator for ThenBB; replaced below by the moved entry branch.
  auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);

  // Emit thenBB and set the Builder's insertion point there for
  // body generation next. Place the block after the current block.
  Function *CurFn = EntryBB->getParent();
  CurFn->getBasicBlockList().insertAfter(EntryBB->getIterator(), ThenBB);

  // Move Entry branch to end of ThenBB, and replace with conditional
  // branch (If-stmt)
  Instruction *EntryBBTI = EntryBB->getTerminator();
  Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
  EntryBBTI->removeFromParent();
  Builder.SetInsertPoint(UI);
  Builder.Insert(EntryBBTI);
  UI->eraseFromParent();
  Builder.SetInsertPoint(ThenBB->getTerminator());

  // return an insertion point to ExitBB.
  return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
}
3441 
/// Emit the exit sequence of a directive region: run the registered
/// finalization callback (popping it from the stack), then move the
/// directive's exit call so it becomes the last instruction before the
/// finalization block's terminator.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
    omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
    bool HasFinalize) {

  Builder.restoreIP(FinIP);

  // If there is finalization to do, emit it before the exit call
  if (HasFinalize) {
    assert(!FinalizationStack.empty() &&
           "Unexpected finalization stack state!");

    // Pop the entry pushed by EmitOMPInlinedRegion; it must belong to the
    // directive currently being closed.
    FinalizationInfo Fi = FinalizationStack.pop_back_val();
    assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");

    Fi.FiniCB(FinIP);

    BasicBlock *FiniBB = FinIP.getBlock();
    Instruction *FiniBBTI = FiniBB->getTerminator();

    // set Builder IP for call creation
    Builder.SetInsertPoint(FiniBBTI);
  }

  // Without an exit call there is nothing more to place; return the point
  // right after the finalization code.
  if (!ExitCall)
    return Builder.saveIP();

  // place the Exitcall as last instruction before Finalization block terminator
  ExitCall->removeFromParent();
  Builder.Insert(ExitCall);

  return IRBuilder<>::InsertPoint(ExitCall->getParent(),
                                  ExitCall->getIterator());
}
3475 
/// Build the control flow for a `copyin` clause: compare the master and
/// private addresses and, when they differ, branch into a block where the
/// caller emits the actual copy code. Returns an insertion point inside
/// the "copyin.not.master" block (before its branch when \p BranchtoEnd).
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
    InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
    llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
  if (!IP.isSet())
    return IP;

  IRBuilder<>::InsertPointGuard IPG(Builder);

  // creates the following CFG structure
  //     OMP_Entry : (MasterAddr != PrivateAddr)?
  //       F     T
  //       |      \
  //       |     copyin.not.master
  //       |      /
  //       v     /
  //   copyin.not.master.end
  //         |
  //         v
  //   OMP.Entry.Next

  BasicBlock *OMP_Entry = IP.getBlock();
  Function *CurFn = OMP_Entry->getParent();
  BasicBlock *CopyBegin =
      BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
  BasicBlock *CopyEnd = nullptr;

  // If entry block is terminated, split to preserve the branch to following
  // basic block (i.e. OMP.Entry.Next), otherwise, leave everything as is.
  if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
    CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
                                         "copyin.not.master.end");
    OMP_Entry->getTerminator()->eraseFromParent();
  } else {
    CopyEnd =
        BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
  }

  // Compare the two addresses as integers; only a thread whose private copy
  // lives at a different address than the master copy needs the copy-in.
  Builder.SetInsertPoint(OMP_Entry);
  Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
  Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
  Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
  Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);

  Builder.SetInsertPoint(CopyBegin);
  if (BranchtoEnd)
    // Leave the insertion point just before the branch so the caller can
    // emit the copy code inside copyin.not.master.
    Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));

  return Builder.saveIP();
}
3525 
createOMPAlloc(const LocationDescription & Loc,Value * Size,Value * Allocator,std::string Name)3526 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
3527                                           Value *Size, Value *Allocator,
3528                                           std::string Name) {
3529   IRBuilder<>::InsertPointGuard IPG(Builder);
3530   Builder.restoreIP(Loc.IP);
3531 
3532   uint32_t SrcLocStrSize;
3533   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3534   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3535   Value *ThreadId = getOrCreateThreadID(Ident);
3536   Value *Args[] = {ThreadId, Size, Allocator};
3537 
3538   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
3539 
3540   return Builder.CreateCall(Fn, Args, Name);
3541 }
3542 
createOMPFree(const LocationDescription & Loc,Value * Addr,Value * Allocator,std::string Name)3543 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
3544                                          Value *Addr, Value *Allocator,
3545                                          std::string Name) {
3546   IRBuilder<>::InsertPointGuard IPG(Builder);
3547   Builder.restoreIP(Loc.IP);
3548 
3549   uint32_t SrcLocStrSize;
3550   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3551   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3552   Value *ThreadId = getOrCreateThreadID(Ident);
3553   Value *Args[] = {ThreadId, Addr, Allocator};
3554   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
3555   return Builder.CreateCall(Fn, Args, Name);
3556 }
3557 
createOMPInteropInit(const LocationDescription & Loc,Value * InteropVar,omp::OMPInteropType InteropType,Value * Device,Value * NumDependences,Value * DependenceAddress,bool HaveNowaitClause)3558 CallInst *OpenMPIRBuilder::createOMPInteropInit(
3559     const LocationDescription &Loc, Value *InteropVar,
3560     omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
3561     Value *DependenceAddress, bool HaveNowaitClause) {
3562   IRBuilder<>::InsertPointGuard IPG(Builder);
3563   Builder.restoreIP(Loc.IP);
3564 
3565   uint32_t SrcLocStrSize;
3566   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3567   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3568   Value *ThreadId = getOrCreateThreadID(Ident);
3569   if (Device == nullptr)
3570     Device = ConstantInt::get(Int32, -1);
3571   Constant *InteropTypeVal = ConstantInt::get(Int64, (int)InteropType);
3572   if (NumDependences == nullptr) {
3573     NumDependences = ConstantInt::get(Int32, 0);
3574     PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
3575     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
3576   }
3577   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
3578   Value *Args[] = {
3579       Ident,  ThreadId,       InteropVar,        InteropTypeVal,
3580       Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
3581 
3582   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
3583 
3584   return Builder.CreateCall(Fn, Args);
3585 }
3586 
createOMPInteropDestroy(const LocationDescription & Loc,Value * InteropVar,Value * Device,Value * NumDependences,Value * DependenceAddress,bool HaveNowaitClause)3587 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
3588     const LocationDescription &Loc, Value *InteropVar, Value *Device,
3589     Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
3590   IRBuilder<>::InsertPointGuard IPG(Builder);
3591   Builder.restoreIP(Loc.IP);
3592 
3593   uint32_t SrcLocStrSize;
3594   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3595   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3596   Value *ThreadId = getOrCreateThreadID(Ident);
3597   if (Device == nullptr)
3598     Device = ConstantInt::get(Int32, -1);
3599   if (NumDependences == nullptr) {
3600     NumDependences = ConstantInt::get(Int32, 0);
3601     PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
3602     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
3603   }
3604   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
3605   Value *Args[] = {
3606       Ident,          ThreadId,          InteropVar,         Device,
3607       NumDependences, DependenceAddress, HaveNowaitClauseVal};
3608 
3609   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
3610 
3611   return Builder.CreateCall(Fn, Args);
3612 }
3613 
createOMPInteropUse(const LocationDescription & Loc,Value * InteropVar,Value * Device,Value * NumDependences,Value * DependenceAddress,bool HaveNowaitClause)3614 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
3615                                                Value *InteropVar, Value *Device,
3616                                                Value *NumDependences,
3617                                                Value *DependenceAddress,
3618                                                bool HaveNowaitClause) {
3619   IRBuilder<>::InsertPointGuard IPG(Builder);
3620   Builder.restoreIP(Loc.IP);
3621   uint32_t SrcLocStrSize;
3622   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3623   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3624   Value *ThreadId = getOrCreateThreadID(Ident);
3625   if (Device == nullptr)
3626     Device = ConstantInt::get(Int32, -1);
3627   if (NumDependences == nullptr) {
3628     NumDependences = ConstantInt::get(Int32, 0);
3629     PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
3630     DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
3631   }
3632   Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
3633   Value *Args[] = {
3634       Ident,          ThreadId,          InteropVar,         Device,
3635       NumDependences, DependenceAddress, HaveNowaitClauseVal};
3636 
3637   Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
3638 
3639   return Builder.CreateCall(Fn, Args);
3640 }
3641 
createCachedThreadPrivate(const LocationDescription & Loc,llvm::Value * Pointer,llvm::ConstantInt * Size,const llvm::Twine & Name)3642 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
3643     const LocationDescription &Loc, llvm::Value *Pointer,
3644     llvm::ConstantInt *Size, const llvm::Twine &Name) {
3645   IRBuilder<>::InsertPointGuard IPG(Builder);
3646   Builder.restoreIP(Loc.IP);
3647 
3648   uint32_t SrcLocStrSize;
3649   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3650   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3651   Value *ThreadId = getOrCreateThreadID(Ident);
3652   Constant *ThreadPrivateCache =
3653       getOrCreateOMPInternalVariable(Int8PtrPtr, Name);
3654   llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
3655 
3656   Function *Fn =
3657       getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
3658 
3659   return Builder.CreateCall(Fn, Args);
3660 }
3661 
/// Emit the __kmpc_target_init call that begins a target region, plus the
/// control flow that lets only threads for which the runtime returned -1
/// continue into the user code while all other threads return immediately.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
                                  bool RequiresFullRuntime) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  // The execution mode is passed as a signed i8 flag.
  ConstantInt *IsSPMDVal = ConstantInt::getSigned(
      IntegerType::getInt8Ty(Int8->getContext()),
      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
  // In generic (non-SPMD) mode the runtime's worker state machine is used.
  ConstantInt *UseGenericStateMachine =
      ConstantInt::getBool(Int32->getContext(), !IsSPMD);
  ConstantInt *RequiresFullRuntimeVal =
      ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);

  Function *Fn = getOrCreateRuntimeFunctionPtr(
      omp::RuntimeFunction::OMPRTL___kmpc_target_init);

  CallInst *ThreadKind = Builder.CreateCall(
      Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});

  // Threads for which the runtime call returned -1 execute the user code.
  Value *ExecUserCode = Builder.CreateICmpEQ(
      ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
      "exec_user_code");

  // ThreadKind = __kmpc_target_init(...)
  // if (ThreadKind == -1)
  //   user_code
  // else
  //   return;

  // Split the current block at a temporary unreachable so we obtain a check
  // block (ending in a branch we rewrite below) and a user-code entry block.
  auto *UI = Builder.CreateUnreachable();
  BasicBlock *CheckBB = UI->getParent();
  BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");

  BasicBlock *WorkerExitBB = BasicBlock::Create(
      CheckBB->getContext(), "worker.exit", CheckBB->getParent());
  Builder.SetInsertPoint(WorkerExitBB);
  Builder.CreateRetVoid();

  // Replace the unconditional split branch with the conditional dispatch.
  auto *CheckBBTI = CheckBB->getTerminator();
  Builder.SetInsertPoint(CheckBBTI);
  Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);

  CheckBBTI->eraseFromParent();
  UI->eraseFromParent();

  // Continue in the "user_code" block, see diagram above and in
  // openmp/libomptarget/deviceRTLs/common/include/target.h .
  return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
}
3715 
createTargetDeinit(const LocationDescription & Loc,bool IsSPMD,bool RequiresFullRuntime)3716 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
3717                                          bool IsSPMD,
3718                                          bool RequiresFullRuntime) {
3719   if (!updateToLocation(Loc))
3720     return;
3721 
3722   uint32_t SrcLocStrSize;
3723   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3724   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3725   ConstantInt *IsSPMDVal = ConstantInt::getSigned(
3726       IntegerType::getInt8Ty(Int8->getContext()),
3727       IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
3728   ConstantInt *RequiresFullRuntimeVal =
3729       ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
3730 
3731   Function *Fn = getOrCreateRuntimeFunctionPtr(
3732       omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
3733 
3734   Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
3735 }
3736 
getNameWithSeparators(ArrayRef<StringRef> Parts,StringRef FirstSeparator,StringRef Separator)3737 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
3738                                                    StringRef FirstSeparator,
3739                                                    StringRef Separator) {
3740   SmallString<128> Buffer;
3741   llvm::raw_svector_ostream OS(Buffer);
3742   StringRef Sep = FirstSeparator;
3743   for (StringRef Part : Parts) {
3744     OS << Sep << Part;
3745     Sep = Separator;
3746   }
3747   return OS.str().str();
3748 }
3749 
/// Return (creating on first request) the module-level global backing the
/// OMP-internal variable \p Name of type \p Ty. Subsequent requests with the
/// same name return the cached global, whose type must match \p Ty.
Constant *OpenMPIRBuilder::getOrCreateOMPInternalVariable(
    llvm::Type *Ty, const llvm::Twine &Name, unsigned AddressSpace) {
  // TODO: Replace the twine arg with stringref to get rid of the conversion
  // logic. However This is taken from current implementation in clang as is.
  // Since this method is used in many places exclusively for OMP internal use
  // we will keep it as is for temporarily until we move all users to the
  // builder and then, if possible, fix it everywhere in one go.
  SmallString<256> Buffer;
  llvm::raw_svector_ostream Out(Buffer);
  Out << Name;
  StringRef RuntimeName = Out.str();
  // try_emplace either finds the cached entry or inserts a null placeholder
  // that is filled in below.
  auto &Elem = *InternalVars.try_emplace(RuntimeName, nullptr).first;
  if (Elem.second) {
    // Cache hit: the previously created global must have the requested type.
    assert(cast<PointerType>(Elem.second->getType())
               ->isOpaqueOrPointeeTypeMatches(Ty) &&
           "OMP internal variable has different type than requested");
  } else {
    // TODO: investigate the appropriate linkage type used for the global
    // variable for possibly changing that to internal or private, or maybe
    // create different versions of the function for different OMP internal
    // variables.
    Elem.second = new llvm::GlobalVariable(
        M, Ty, /*IsConstant*/ false, llvm::GlobalValue::CommonLinkage,
        llvm::Constant::getNullValue(Ty), Elem.first(),
        /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
        AddressSpace);
  }

  return Elem.second;
}
3780 
getOMPCriticalRegionLock(StringRef CriticalName)3781 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
3782   std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
3783   std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
3784   return getOrCreateOMPInternalVariable(KmpCriticalNameTy, Name);
3785 }
3786 
3787 GlobalVariable *
createOffloadMaptypes(SmallVectorImpl<uint64_t> & Mappings,std::string VarName)3788 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
3789                                        std::string VarName) {
3790   llvm::Constant *MaptypesArrayInit =
3791       llvm::ConstantDataArray::get(M.getContext(), Mappings);
3792   auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
3793       M, MaptypesArrayInit->getType(),
3794       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
3795       VarName);
3796   MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
3797   return MaptypesArrayGlobal;
3798 }
3799 
createMapperAllocas(const LocationDescription & Loc,InsertPointTy AllocaIP,unsigned NumOperands,struct MapperAllocas & MapperAllocas)3800 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
3801                                           InsertPointTy AllocaIP,
3802                                           unsigned NumOperands,
3803                                           struct MapperAllocas &MapperAllocas) {
3804   if (!updateToLocation(Loc))
3805     return;
3806 
3807   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
3808   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
3809   Builder.restoreIP(AllocaIP);
3810   AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI8PtrTy);
3811   AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy);
3812   AllocaInst *ArgSizes = Builder.CreateAlloca(ArrI64Ty);
3813   Builder.restoreIP(Loc.IP);
3814   MapperAllocas.ArgsBase = ArgsBase;
3815   MapperAllocas.Args = Args;
3816   MapperAllocas.ArgSizes = ArgSizes;
3817 }
3818 
emitMapperCall(const LocationDescription & Loc,Function * MapperFunc,Value * SrcLocInfo,Value * MaptypesArg,Value * MapnamesArg,struct MapperAllocas & MapperAllocas,int64_t DeviceID,unsigned NumOperands)3819 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
3820                                      Function *MapperFunc, Value *SrcLocInfo,
3821                                      Value *MaptypesArg, Value *MapnamesArg,
3822                                      struct MapperAllocas &MapperAllocas,
3823                                      int64_t DeviceID, unsigned NumOperands) {
3824   if (!updateToLocation(Loc))
3825     return;
3826 
3827   auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
3828   auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
3829   Value *ArgsBaseGEP =
3830       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
3831                                 {Builder.getInt32(0), Builder.getInt32(0)});
3832   Value *ArgsGEP =
3833       Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
3834                                 {Builder.getInt32(0), Builder.getInt32(0)});
3835   Value *ArgSizesGEP =
3836       Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
3837                                 {Builder.getInt32(0), Builder.getInt32(0)});
3838   Value *NullPtr = Constant::getNullValue(Int8Ptr->getPointerTo());
3839   Builder.CreateCall(MapperFunc,
3840                      {SrcLocInfo, Builder.getInt64(DeviceID),
3841                       Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
3842                       ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
3843 }
3844 
checkAndEmitFlushAfterAtomic(const LocationDescription & Loc,llvm::AtomicOrdering AO,AtomicKind AK)3845 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
3846     const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
3847   assert(!(AO == AtomicOrdering::NotAtomic ||
3848            AO == llvm::AtomicOrdering::Unordered) &&
3849          "Unexpected Atomic Ordering.");
3850 
3851   bool Flush = false;
3852   llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
3853 
3854   switch (AK) {
3855   case Read:
3856     if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
3857         AO == AtomicOrdering::SequentiallyConsistent) {
3858       FlushAO = AtomicOrdering::Acquire;
3859       Flush = true;
3860     }
3861     break;
3862   case Write:
3863   case Compare:
3864   case Update:
3865     if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
3866         AO == AtomicOrdering::SequentiallyConsistent) {
3867       FlushAO = AtomicOrdering::Release;
3868       Flush = true;
3869     }
3870     break;
3871   case Capture:
3872     switch (AO) {
3873     case AtomicOrdering::Acquire:
3874       FlushAO = AtomicOrdering::Acquire;
3875       Flush = true;
3876       break;
3877     case AtomicOrdering::Release:
3878       FlushAO = AtomicOrdering::Release;
3879       Flush = true;
3880       break;
3881     case AtomicOrdering::AcquireRelease:
3882     case AtomicOrdering::SequentiallyConsistent:
3883       FlushAO = AtomicOrdering::AcquireRelease;
3884       Flush = true;
3885       break;
3886     default:
3887       // do nothing - leave silently.
3888       break;
3889     }
3890   }
3891 
3892   if (Flush) {
3893     // Currently Flush RT call still doesn't take memory_ordering, so for when
3894     // that happens, this tries to do the resolution of which atomic ordering
3895     // to use with but issue the flush call
3896     // TODO: pass `FlushAO` after memory ordering support is added
3897     (void)FlushAO;
3898     emitFlush(Loc);
3899   }
3900 
3901   // for AO == AtomicOrdering::Monotonic and  all other case combinations
3902   // do nothing
3903   return Flush;
3904 }
3905 
/// Emit an OpenMP `atomic read`: atomically load the value at *X and store
/// it (non-atomically) into *V.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
                                  AtomicOpValue &X, AtomicOpValue &V,
                                  AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Type *XTy = X.Var->getType();
  assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic read expected a scalar type");

  Value *XRead = nullptr;

  if (XElemTy->isIntegerTy()) {
    // Integers can be loaded atomically as-is.
    LoadInst *XLD =
        Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
    XLD->setAtomic(AO);
    XRead = cast<Value>(XLD);
  } else {
    // We need to bitcast and perform atomic op as integer
    unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast = Builder.CreateBitCast(
        X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.src.int.cast");
    LoadInst *XLoad =
        Builder.CreateLoad(IntCastTy, XBCast, X.IsVolatile, "omp.atomic.load");
    XLoad->setAtomic(AO);
    // Convert the loaded integer bits back to the requested element type.
    if (XElemTy->isFloatingPointTy()) {
      XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
    } else {
      XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
    }
  }
  // Acquire-like orderings may require a trailing flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
  Builder.CreateStore(XRead, V.Var, V.IsVolatile);
  return Builder.saveIP();
}
3947 
/// Emit an OpenMP `atomic write`: atomically store \p Expr into *X.
OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
                                   AtomicOpValue &X, Value *Expr,
                                   AtomicOrdering AO) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  Type *XTy = X.Var->getType();
  assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
  Type *XElemTy = X.ElemTy;
  assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
          XElemTy->isPointerTy()) &&
         "OMP atomic write expected a scalar type");

  if (XElemTy->isIntegerTy()) {
    // Integers can be stored atomically as-is.
    StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
    XSt->setAtomic(AO);
  } else {
    // We need to bitcast and perform atomic op as integers
    unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast = Builder.CreateBitCast(
        X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.dst.int.cast");
    Value *ExprCast =
        Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
    StoreInst *XSt = Builder.CreateStore(ExprCast, XBCast, X.IsVolatile);
    XSt->setAtomic(AO);
  }

  // Release-like orderings may require a trailing flush.
  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
  return Builder.saveIP();
}
3981 
createAtomicUpdate(const LocationDescription & Loc,InsertPointTy AllocaIP,AtomicOpValue & X,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool IsXBinopExpr)3982 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
3983     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
3984     Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
3985     AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
3986   assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
3987   if (!updateToLocation(Loc))
3988     return Loc.IP;
3989 
3990   LLVM_DEBUG({
3991     Type *XTy = X.Var->getType();
3992     assert(XTy->isPointerTy() &&
3993            "OMP Atomic expects a pointer to target memory");
3994     Type *XElemTy = X.ElemTy;
3995     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3996             XElemTy->isPointerTy()) &&
3997            "OMP atomic update expected a scalar type");
3998     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
3999            (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
4000            "OpenMP atomic does not support LT or GT operations");
4001   });
4002 
4003   emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
4004                    X.IsVolatile, IsXBinopExpr);
4005   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
4006   return Builder.saveIP();
4007 }
4008 
emitRMWOpAsInstruction(Value * Src1,Value * Src2,AtomicRMWInst::BinOp RMWOp)4009 Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
4010                                                AtomicRMWInst::BinOp RMWOp) {
4011   switch (RMWOp) {
4012   case AtomicRMWInst::Add:
4013     return Builder.CreateAdd(Src1, Src2);
4014   case AtomicRMWInst::Sub:
4015     return Builder.CreateSub(Src1, Src2);
4016   case AtomicRMWInst::And:
4017     return Builder.CreateAnd(Src1, Src2);
4018   case AtomicRMWInst::Nand:
4019     return Builder.CreateNeg(Builder.CreateAnd(Src1, Src2));
4020   case AtomicRMWInst::Or:
4021     return Builder.CreateOr(Src1, Src2);
4022   case AtomicRMWInst::Xor:
4023     return Builder.CreateXor(Src1, Src2);
4024   case AtomicRMWInst::Xchg:
4025   case AtomicRMWInst::FAdd:
4026   case AtomicRMWInst::FSub:
4027   case AtomicRMWInst::BAD_BINOP:
4028   case AtomicRMWInst::Max:
4029   case AtomicRMWInst::Min:
4030   case AtomicRMWInst::UMax:
4031   case AtomicRMWInst::UMin:
4032   case AtomicRMWInst::FMax:
4033   case AtomicRMWInst::FMin:
4034     llvm_unreachable("Unsupported atomic update operation");
4035   }
4036   llvm_unreachable("Unsupported atomic update operation");
4037 }
4038 
/// Emit IR that atomically updates the memory pointed to by \p X with \p Expr
/// combined via \p RMWOp, or via the generic \p UpdateOp callback when no
/// direct atomicrmw lowering exists. Returns the pair
/// {old value of *X, updated value}.
std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
    InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
    AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
  // TODO: handle the case where XElemTy is not byte-sized or not a power of 2
  // or a complex datatype.
  // Decide whether the update can be expressed as a single atomicrmw
  // instruction instead of a cmpxchg retry loop.
  bool emitRMWOp = false;
  switch (RMWOp) {
  case AtomicRMWInst::Add:
  case AtomicRMWInst::And:
  case AtomicRMWInst::Nand:
  case AtomicRMWInst::Or:
  case AtomicRMWInst::Xor:
  case AtomicRMWInst::Xchg:
    // Directly representable, provided the element type is known.
    emitRMWOp = XElemTy;
    break;
  case AtomicRMWInst::Sub:
    // Only 'x - expr' maps onto atomicrmw sub; 'expr - x' does not.
    emitRMWOp = (IsXBinopExpr && XElemTy);
    break;
  default:
    emitRMWOp = false;
  }
  // The atomicrmw fast path is only taken for integer element types.
  emitRMWOp &= XElemTy->isIntegerTy();

  std::pair<Value *, Value *> Res;
  if (emitRMWOp) {
    // Fast path: atomicrmw returns the old value of *X.
    Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
    // The recomputed new value is not needed except in case of postfix
    // captures. Generate anyway for consistency with the else part. Will be
    // removed with any DCE pass.
    // AtomicRMWInst::Xchg does not have a corresponding instruction.
    if (RMWOp == AtomicRMWInst::Xchg)
      Res.second = Res.first;
    else
      Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
  } else {
    // Slow path: cmpxchg retry loop. Operate on an integer of the same bit
    // width as XElemTy, bitcasting pointers and values as required.
    unsigned Addrspace = cast<PointerType>(X->getType())->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast =
        Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
    LoadInst *OldVal =
        Builder.CreateLoad(IntCastTy, XBCast, X->getName() + ".atomic.load");
    OldVal->setAtomic(AO);
    // CurBB
    // |     /---\
    // ContBB    |
    // |     \---/
    // ExitBB
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminator();
    // If the current block has no terminator yet, create a placeholder
    // unreachable so the block can be split; it is cleaned up below.
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
                                                X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    // Scratch stack slot used to round-trip the updated value through memory
    // so it can be reloaded with the integer cast type.
    Builder.restoreIP(AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    // The PHI carries the currently observed value of *X: the initial load on
    // entry, the cmpxchg-observed value on each retry.
    llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
    PHI->addIncoming(OldVal, CurBB);
    IntegerType *NewAtomicCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    bool IsIntTy = XElemTy->isIntegerTy();
    Value *NewAtomicIntAddr =
        (IsIntTy)
            ? NewAtomicAddr
            : Builder.CreateBitCast(NewAtomicAddr,
                                    NewAtomicCastTy->getPointerTo(Addrspace));
    // Present the old value to the update callback in its natural type.
    Value *OldExprVal = PHI;
    if (!IsIntTy) {
      if (XElemTy->isFloatingPointTy()) {
        OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
                                           X->getName() + ".atomic.fltCast");
      } else {
        OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
                                            X->getName() + ".atomic.ptrCast");
      }
    }

    // Let the callback compute the new value, then reload it as an integer
    // for the cmpxchg below.
    Value *Upd = UpdateOp(OldExprVal, Builder);
    Builder.CreateStore(Upd, NewAtomicAddr);
    LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicIntAddr);
    Value *XAddr =
        (IsIntTy)
            ? X
            : Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
        XAddr, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
    Result->setVolatile(VolatileX);
    Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
    Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
    // On failure, loop back with the value the cmpxchg actually observed.
    PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
    Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);

    Res.first = OldExprVal;
    Res.second = Upd;

    // set Insertion point in exit block
    if (UnreachableInst *ExitTI =
            dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
      // The placeholder unreachable created above ended up as ExitBB's
      // terminator; drop it and append to ExitBB.
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      // NOTE(review): on this path the dyn_cast failed, so ExitTI is null and
      // SetInsertPoint would dereference a null Instruction*. Presumably this
      // branch is never taken in practice — TODO confirm and clean up.
      Builder.SetInsertPoint(ExitTI);
    }
  }

  return Res;
}
4152 
createAtomicCapture(const LocationDescription & Loc,InsertPointTy AllocaIP,AtomicOpValue & X,AtomicOpValue & V,Value * Expr,AtomicOrdering AO,AtomicRMWInst::BinOp RMWOp,AtomicUpdateCallbackTy & UpdateOp,bool UpdateExpr,bool IsPostfixUpdate,bool IsXBinopExpr)4153 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
4154     const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
4155     AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
4156     AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
4157     bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
4158   if (!updateToLocation(Loc))
4159     return Loc.IP;
4160 
4161   LLVM_DEBUG({
4162     Type *XTy = X.Var->getType();
4163     assert(XTy->isPointerTy() &&
4164            "OMP Atomic expects a pointer to target memory");
4165     Type *XElemTy = X.ElemTy;
4166     assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
4167             XElemTy->isPointerTy()) &&
4168            "OMP atomic capture expected a scalar type");
4169     assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
4170            "OpenMP atomic does not support LT or GT operations");
4171   });
4172 
4173   // If UpdateExpr is 'x' updated with some `expr` not based on 'x',
4174   // 'x' is simply atomically rewritten with 'expr'.
4175   AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
4176   std::pair<Value *, Value *> Result =
4177       emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
4178                        X.IsVolatile, IsXBinopExpr);
4179 
4180   Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
4181   Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);
4182 
4183   checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
4184   return Builder.saveIP();
4185 }
4186 
/// Emit an 'omp atomic compare' construct. For Op == EQ this lowers to a
/// cmpxchg of *X.Var against \p E with desired value \p D; for MAX/MIN it
/// lowers to an atomicrmw min/max. Optionally captures a value into *V.Var
/// and the comparison result into *R.Var.
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // compare capture: 'v' (if present) receives a captured value below.
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      // cmpxchg operates on integers: bitcast the pointer and both operands
      // to an integer type of the same bit width.
      unsigned Addrspace =
          cast<PointerType>(X.Var->getType())->getAddressSpace();
      IntegerType *IntCastTy =
          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
      Value *XBCast =
          Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
      Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
      Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(XBCast, EBCast, DBCast, MaybeAlign(),
                                           AO, Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
    }

    if (V.Var) {
      // Extract the previously stored value of 'x' from the cmpxchg result.
      Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        // Postfix capture: 'v' always gets the old value of 'x'.
        Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // 'v' is only written when the comparison FAILED, so branch around
          // the store on success:
          // CurBB----
          //   |     |
          //   v     |
          // ContBB  |
          //   |     |
          //   v     |
          // ExitBB <-
          //
          // where ContBB only contains the store of old value to 'v'.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminator();
          // Create a placeholder terminator if the block has none yet; it is
          // cleaned up once ExitBB is finalized below.
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              CurBBTI, X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(OldValue, V.Var);
          Builder.CreateBr(ExitBB);

          if (UnreachableInst *ExitTI =
                  dyn_cast<UnreachableInst>(ExitBB->getTerminator())) {
            // Placeholder unreachable became ExitBB's terminator; drop it.
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            // NOTE(review): ExitTI is null here (the dyn_cast failed), so
            // SetInsertPoint would dereference a null Instruction*.
            // Presumably this path is never taken — TODO confirm.
            Builder.SetInsertPoint(ExitTI);
          }
        } else {
          // Prefix capture: 'v' gets 'e' on success, the old value otherwise.
          Value *CapturedValue =
              Builder.CreateSelect(SuccessOrFail, E, OldValue);
          Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
        }
      }
    }
    // The comparison result has to be stored.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      // Widen the i1 success flag to R's integer type with the matching
      // signedness.
      Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
                              : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
      Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // Reverse the ordop as the OpenMP forms are different from LLVM forms.
    // Let's take max as example.
    // OpenMP form:
    // x = x > expr ? expr : x;
    // LLVM form:
    // *ptr = *ptr > val ? *ptr : val;
    // We need to transform to LLVM form.
    // x = x <= expr ? x : expr;
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      // 'expr binop x' form: no reversal needed.
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    // atomicrmw returns the old value of *X.Var.
    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        CapturedValue = OldValue;
      } else {
        // Prefix capture: recompute what the atomicrmw stored by repeating
        // the min/max selection non-atomically on the returned old value.
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
        CapturedValue = Builder.CreateSelect(NonAtomicCmp, E, OldValue);
      }
      Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
    }
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);

  return Builder.saveIP();
}
4368 
4369 GlobalVariable *
createOffloadMapnames(SmallVectorImpl<llvm::Constant * > & Names,std::string VarName)4370 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
4371                                        std::string VarName) {
4372   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
4373       llvm::ArrayType::get(
4374           llvm::Type::getInt8Ty(M.getContext())->getPointerTo(), Names.size()),
4375       Names);
4376   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
4377       M, MapNamesArrayInit->getType(),
4378       /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
4379       VarName);
4380   return MapNamesArrayGlobal;
4381 }
4382 
// Create all simple and struct types exposed by the runtime and remember
// the llvm::PointerTypes of them for easy access later.
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
  // Each OMP_*_TYPE macro below is expanded once per entry in OMPKinds.def,
  // initializing the corresponding member of this builder together with its
  // pointer-type twin (VarName##Ptr / VarName##PtrTy). Struct types are
  // looked up by name first so an existing definition in the context is
  // reused rather than duplicated.
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                             \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                             \
  VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                  \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);            \
  VarName##Ptr = PointerType::getUnqual(VarName);
#define OMP_STRUCT_TYPE(VarName, StructName, ...)                              \
  T = StructType::getTypeByName(Ctx, StructName);                              \
  if (!T)                                                                      \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName);                    \
  VarName = T;                                                                 \
  VarName##Ptr = PointerType::getUnqual(T);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
4403 
collectBlocks(SmallPtrSetImpl<BasicBlock * > & BlockSet,SmallVectorImpl<BasicBlock * > & BlockVector)4404 void OpenMPIRBuilder::OutlineInfo::collectBlocks(
4405     SmallPtrSetImpl<BasicBlock *> &BlockSet,
4406     SmallVectorImpl<BasicBlock *> &BlockVector) {
4407   SmallVector<BasicBlock *, 32> Worklist;
4408   BlockSet.insert(EntryBB);
4409   BlockSet.insert(ExitBB);
4410 
4411   Worklist.push_back(EntryBB);
4412   while (!Worklist.empty()) {
4413     BasicBlock *BB = Worklist.pop_back_val();
4414     BlockVector.push_back(BB);
4415     for (BasicBlock *SuccBB : successors(BB))
4416       if (BlockSet.insert(SuccBB).second)
4417         Worklist.push_back(SuccBB);
4418   }
4419 }
4420 
collectControlBlocks(SmallVectorImpl<BasicBlock * > & BBs)4421 void CanonicalLoopInfo::collectControlBlocks(
4422     SmallVectorImpl<BasicBlock *> &BBs) {
4423   // We only count those BBs as control block for which we do not need to
4424   // reverse the CFG, i.e. not the loop body which can contain arbitrary control
4425   // flow. For consistency, this also means we do not add the Body block, which
4426   // is just the entry to the body code.
4427   BBs.reserve(BBs.size() + 6);
4428   BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
4429 }
4430 
getPreheader() const4431 BasicBlock *CanonicalLoopInfo::getPreheader() const {
4432   assert(isValid() && "Requires a valid canonical loop");
4433   for (BasicBlock *Pred : predecessors(Header)) {
4434     if (Pred != Latch)
4435       return Pred;
4436   }
4437   llvm_unreachable("Missing preheader");
4438 }
4439 
setTripCount(Value * TripCount)4440 void CanonicalLoopInfo::setTripCount(Value *TripCount) {
4441   assert(isValid() && "Requires a valid canonical loop");
4442 
4443   Instruction *CmpI = &getCond()->front();
4444   assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
4445   CmpI->setOperand(1, TripCount);
4446 
4447 #ifndef NDEBUG
4448   assertOK();
4449 #endif
4450 }
4451 
mapIndVar(llvm::function_ref<Value * (Instruction *)> Updater)4452 void CanonicalLoopInfo::mapIndVar(
4453     llvm::function_ref<Value *(Instruction *)> Updater) {
4454   assert(isValid() && "Requires a valid canonical loop");
4455 
4456   Instruction *OldIV = getIndVar();
4457 
4458   // Record all uses excluding those introduced by the updater. Uses by the
4459   // CanonicalLoopInfo itself to keep track of the number of iterations are
4460   // excluded.
4461   SmallVector<Use *> ReplacableUses;
4462   for (Use &U : OldIV->uses()) {
4463     auto *User = dyn_cast<Instruction>(U.getUser());
4464     if (!User)
4465       continue;
4466     if (User->getParent() == getCond())
4467       continue;
4468     if (User->getParent() == getLatch())
4469       continue;
4470     ReplacableUses.push_back(&U);
4471   }
4472 
4473   // Run the updater that may introduce new uses
4474   Value *NewIV = Updater(OldIV);
4475 
4476   // Replace the old uses with the value returned by the updater.
4477   for (Use *U : ReplacableUses)
4478     U->set(NewIV);
4479 
4480 #ifndef NDEBUG
4481   assertOK();
4482 #endif
4483 }
4484 
assertOK() const4485 void CanonicalLoopInfo::assertOK() const {
4486 #ifndef NDEBUG
4487   // No constraints if this object currently does not describe a loop.
4488   if (!isValid())
4489     return;
4490 
4491   BasicBlock *Preheader = getPreheader();
4492   BasicBlock *Body = getBody();
4493   BasicBlock *After = getAfter();
4494 
4495   // Verify standard control-flow we use for OpenMP loops.
4496   assert(Preheader);
4497   assert(isa<BranchInst>(Preheader->getTerminator()) &&
4498          "Preheader must terminate with unconditional branch");
4499   assert(Preheader->getSingleSuccessor() == Header &&
4500          "Preheader must jump to header");
4501 
4502   assert(Header);
4503   assert(isa<BranchInst>(Header->getTerminator()) &&
4504          "Header must terminate with unconditional branch");
4505   assert(Header->getSingleSuccessor() == Cond &&
4506          "Header must jump to exiting block");
4507 
4508   assert(Cond);
4509   assert(Cond->getSinglePredecessor() == Header &&
4510          "Exiting block only reachable from header");
4511 
4512   assert(isa<BranchInst>(Cond->getTerminator()) &&
4513          "Exiting block must terminate with conditional branch");
4514   assert(size(successors(Cond)) == 2 &&
4515          "Exiting block must have two successors");
4516   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
4517          "Exiting block's first successor jump to the body");
4518   assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
4519          "Exiting block's second successor must exit the loop");
4520 
4521   assert(Body);
4522   assert(Body->getSinglePredecessor() == Cond &&
4523          "Body only reachable from exiting block");
4524   assert(!isa<PHINode>(Body->front()));
4525 
4526   assert(Latch);
4527   assert(isa<BranchInst>(Latch->getTerminator()) &&
4528          "Latch must terminate with unconditional branch");
4529   assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
4530   // TODO: To support simple redirecting of the end of the body code that has
4531   // multiple; introduce another auxiliary basic block like preheader and after.
4532   assert(Latch->getSinglePredecessor() != nullptr);
4533   assert(!isa<PHINode>(Latch->front()));
4534 
4535   assert(Exit);
4536   assert(isa<BranchInst>(Exit->getTerminator()) &&
4537          "Exit block must terminate with unconditional branch");
4538   assert(Exit->getSingleSuccessor() == After &&
4539          "Exit block must jump to after block");
4540 
4541   assert(After);
4542   assert(After->getSinglePredecessor() == Exit &&
4543          "After block only reachable from exit block");
4544   assert(After->empty() || !isa<PHINode>(After->front()));
4545 
4546   Instruction *IndVar = getIndVar();
4547   assert(IndVar && "Canonical induction variable not found?");
4548   assert(isa<IntegerType>(IndVar->getType()) &&
4549          "Induction variable must be an integer");
4550   assert(cast<PHINode>(IndVar)->getParent() == Header &&
4551          "Induction variable must be a PHI in the loop header");
4552   assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
4553   assert(
4554       cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
4555   assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);
4556 
4557   auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
4558   assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
4559   assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
4560   assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
4561   assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
4562              ->isOne());
4563 
4564   Value *TripCount = getTripCount();
4565   assert(TripCount && "Loop trip count not found?");
4566   assert(IndVar->getType() == TripCount->getType() &&
4567          "Trip count and induction variable must have the same type");
4568 
4569   auto *CmpI = cast<CmpInst>(&Cond->front());
4570   assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
4571          "Exit condition must be a signed less-than comparison");
4572   assert(CmpI->getOperand(0) == IndVar &&
4573          "Exit condition must compare the induction variable");
4574   assert(CmpI->getOperand(1) == TripCount &&
4575          "Exit condition must compare with the trip count");
4576 #endif
4577 }
4578 
invalidate()4579 void CanonicalLoopInfo::invalidate() {
4580   Header = nullptr;
4581   Cond = nullptr;
4582   Latch = nullptr;
4583   Exit = nullptr;
4584 }
4585