//===- OpenMPIRBuilder.cpp - Builder for LLVM-IR for OpenMP directives ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
/// \file
///
/// This file implements the OpenMPIRBuilder class, which is used as a
/// convenient way to create LLVM instructions for OpenMP directives.
///
//===----------------------------------------------------------------------===//

#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/PassManager.h"
#include "llvm/IR/Value.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/CodeExtractor.h"
#include "llvm/Transforms/Utils/LoopPeel.h"
#include "llvm/Transforms/Utils/UnrollLoop.h"

#include <cstdint>

#define DEBUG_TYPE "openmp-ir-builder"

using namespace llvm;
using namespace omp;

static cl::opt<bool>
    OptimisticAttributes("openmp-ir-builder-optimistic-attributes", cl::Hidden,
                         cl::desc("Use optimistic attributes describing "
                                  "'as-if' properties of runtime calls."),
                         cl::init(false));

static cl::opt<double> UnrollThresholdFactor(
    "openmp-ir-builder-unroll-threshold-factor", cl::Hidden,
    cl::desc("Factor for the unroll threshold to account for code "
             "simplifications still taking place"),
    cl::init(1.5));
#ifndef NDEBUG
/// Return whether IP1 and IP2 are ambiguous, i.e., whether inserting
/// instructions at position IP1 may change the meaning of IP2 or vice-versa.
/// This is because an InsertPoint stores the instruction before something is
/// inserted. For instance, if both point to the same instruction, two
/// IRBuilders alternately creating instructions will cause the instructions
/// to be interleaved.
static bool isConflictIP(IRBuilder<>::InsertPoint IP1,
                         IRBuilder<>::InsertPoint IP2) {
  if (!IP1.isSet() || !IP2.isSet())
    return false;
  return IP1.getBlock() == IP2.getBlock() && IP1.getPoint() == IP2.getPoint();
}

static bool isValidWorkshareLoopScheduleType(OMPScheduleType SchedType) {
  // Valid ordered/unordered and base algorithm combinations.
  switch (SchedType & ~OMPScheduleType::MonotonicityMask) {
  case OMPScheduleType::UnorderedStaticChunked:
  case OMPScheduleType::UnorderedStatic:
  case OMPScheduleType::UnorderedDynamicChunked:
  case OMPScheduleType::UnorderedGuidedChunked:
  case OMPScheduleType::UnorderedRuntime:
  case OMPScheduleType::UnorderedAuto:
  case OMPScheduleType::UnorderedTrapezoidal:
  case OMPScheduleType::UnorderedGreedy:
  case OMPScheduleType::UnorderedBalanced:
  case OMPScheduleType::UnorderedGuidedIterativeChunked:
  case OMPScheduleType::UnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::UnorderedSteal:
  case OMPScheduleType::UnorderedStaticBalancedChunked:
  case OMPScheduleType::UnorderedGuidedSimd:
  case OMPScheduleType::UnorderedRuntimeSimd:
  case OMPScheduleType::OrderedStaticChunked:
  case OMPScheduleType::OrderedStatic:
  case OMPScheduleType::OrderedDynamicChunked:
  case OMPScheduleType::OrderedGuidedChunked:
  case OMPScheduleType::OrderedRuntime:
  case OMPScheduleType::OrderedAuto:
  case OMPScheduleType::OrderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedStaticChunked:
  case OMPScheduleType::NomergeUnorderedStatic:
  case OMPScheduleType::NomergeUnorderedDynamicChunked:
  case OMPScheduleType::NomergeUnorderedGuidedChunked:
  case OMPScheduleType::NomergeUnorderedRuntime:
  case OMPScheduleType::NomergeUnorderedAuto:
  case OMPScheduleType::NomergeUnorderedTrapezoidal:
  case OMPScheduleType::NomergeUnorderedGreedy:
  case OMPScheduleType::NomergeUnorderedBalanced:
  case OMPScheduleType::NomergeUnorderedGuidedIterativeChunked:
  case OMPScheduleType::NomergeUnorderedGuidedAnalyticalChunked:
  case OMPScheduleType::NomergeUnorderedSteal:
  case OMPScheduleType::NomergeOrderedStaticChunked:
  case OMPScheduleType::NomergeOrderedStatic:
  case OMPScheduleType::NomergeOrderedDynamicChunked:
  case OMPScheduleType::NomergeOrderedGuidedChunked:
  case OMPScheduleType::NomergeOrderedRuntime:
  case OMPScheduleType::NomergeOrderedAuto:
  case OMPScheduleType::NomergeOrderedTrapezoidal:
    break;
  default:
    return false;
  }

  // Must not set both monotonicity modifiers at the same time.
  OMPScheduleType MonotonicityFlags =
      SchedType & OMPScheduleType::MonotonicityMask;
  if (MonotonicityFlags == OMPScheduleType::MonotonicityMask)
    return false;

  return true;
}
#endif

/// Determine which scheduling algorithm to use based on the schedule clause
/// arguments.
static OMPScheduleType
getOpenMPBaseScheduleType(llvm::omp::ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier) {
  // Currently, the default schedule is static.
  switch (ClauseKind) {
  case OMP_SCHEDULE_Default:
  case OMP_SCHEDULE_Static:
    return HasChunks ? OMPScheduleType::BaseStaticChunked
                     : OMPScheduleType::BaseStatic;
  case OMP_SCHEDULE_Dynamic:
    return OMPScheduleType::BaseDynamicChunked;
  case OMP_SCHEDULE_Guided:
    return HasSimdModifier ? OMPScheduleType::BaseGuidedSimd
                           : OMPScheduleType::BaseGuidedChunked;
  case OMP_SCHEDULE_Auto:
    return llvm::omp::OMPScheduleType::BaseAuto;
  case OMP_SCHEDULE_Runtime:
    return HasSimdModifier ? OMPScheduleType::BaseRuntimeSimd
                           : OMPScheduleType::BaseRuntime;
  }
  llvm_unreachable("unhandled schedule clause argument");
}

/// Adds ordering modifier flags to schedule type.
static OMPScheduleType
getOpenMPOrderingScheduleType(OMPScheduleType BaseScheduleType,
                              bool HasOrderedClause) {
  assert((BaseScheduleType & OMPScheduleType::ModifierMask) ==
             OMPScheduleType::None &&
         "Must not have ordering or monotonicity flags already set");

  OMPScheduleType OrderingModifier = HasOrderedClause
                                         ? OMPScheduleType::ModifierOrdered
                                         : OMPScheduleType::ModifierUnordered;
  OMPScheduleType OrderingScheduleType = BaseScheduleType | OrderingModifier;

  // Unsupported combinations are mapped to the closest supported schedule.
  if (OrderingScheduleType ==
      (OMPScheduleType::BaseGuidedSimd | OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedGuidedChunked;
  else if (OrderingScheduleType == (OMPScheduleType::BaseRuntimeSimd |
                                    OMPScheduleType::ModifierOrdered))
    return OMPScheduleType::OrderedRuntime;

  return OrderingScheduleType;
}

/// Adds monotonicity modifier flags to schedule type.
static OMPScheduleType
getOpenMPMonotonicityScheduleType(OMPScheduleType ScheduleType,
                                  bool HasSimdModifier, bool HasMonotonic,
                                  bool HasNonmonotonic, bool HasOrderedClause) {
  assert((ScheduleType & OMPScheduleType::MonotonicityMask) ==
             OMPScheduleType::None &&
         "Must not have monotonicity flags already set");
  assert((!HasMonotonic || !HasNonmonotonic) &&
         "Monotonic and Nonmonotonic contradict each other");

  if (HasMonotonic) {
    return ScheduleType | OMPScheduleType::ModifierMonotonic;
  } else if (HasNonmonotonic) {
    return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
  } else {
    // OpenMP 5.1, 2.11.4 Worksharing-Loop Construct, Description.
    // If the static schedule kind is specified or if the ordered clause is
    // specified, and if the nonmonotonic modifier is not specified, the
    // effect is as if the monotonic modifier is specified. Otherwise, unless
    // the monotonic modifier is specified, the effect is as if the
    // nonmonotonic modifier is specified.
    OMPScheduleType BaseScheduleType =
        ScheduleType & ~OMPScheduleType::ModifierMask;
    if ((BaseScheduleType == OMPScheduleType::BaseStatic) ||
        (BaseScheduleType == OMPScheduleType::BaseStaticChunked) ||
        HasOrderedClause) {
      // Monotonic is the default in the OpenMP runtime library, so there is
      // no need to set it.
      return ScheduleType;
    } else {
      return ScheduleType | OMPScheduleType::ModifierNonmonotonic;
    }
  }
}

/// Determine the schedule type using schedule and ordering clause arguments.
static OMPScheduleType
computeOpenMPScheduleType(ScheduleKind ClauseKind, bool HasChunks,
                          bool HasSimdModifier, bool HasMonotonicModifier,
                          bool HasNonmonotonicModifier, bool HasOrderedClause) {
  OMPScheduleType BaseSchedule =
      getOpenMPBaseScheduleType(ClauseKind, HasChunks, HasSimdModifier);
  OMPScheduleType OrderedSchedule =
      getOpenMPOrderingScheduleType(BaseSchedule, HasOrderedClause);
  OMPScheduleType Result = getOpenMPMonotonicityScheduleType(
      OrderedSchedule, HasSimdModifier, HasMonotonicModifier,
      HasNonmonotonicModifier, HasOrderedClause);

  assert(isValidWorkshareLoopScheduleType(Result));
  return Result;
}
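
// For illustration of the composition above (a sketch, not an exhaustive
// list): `schedule(dynamic, 4)` without an ordered clause composes to
// BaseDynamicChunked | ModifierUnordered | ModifierNonmonotonic, whereas
// plain `schedule(static)` composes to BaseStatic | ModifierUnordered with
// no monotonicity bit set, since monotonic is already the runtime's default
// for static schedules.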

/// Make \p Source branch to \p Target.
///
/// Handles two situations:
/// * \p Source already has an unconditional branch.
/// * \p Source is a degenerate block (no terminator because the BB is
///   the current head of the IR construction).
static void redirectTo(BasicBlock *Source, BasicBlock *Target, DebugLoc DL) {
  if (Instruction *Term = Source->getTerminator()) {
    auto *Br = cast<BranchInst>(Term);
    assert(!Br->isConditional() &&
           "BB's terminator must be an unconditional branch (or degenerate)");
    BasicBlock *Succ = Br->getSuccessor(0);
    Succ->removePredecessor(Source, /*KeepOneInputPHIs=*/true);
    Br->setSuccessor(0, Target);
    return;
  }

  auto *NewBr = BranchInst::Create(Target, Source);
  NewBr->setDebugLoc(DL);
}

void llvm::spliceBB(IRBuilderBase::InsertPoint IP, BasicBlock *New,
                    bool CreateBranch) {
  assert(New->getFirstInsertionPt() == New->begin() &&
         "Target BB must not have PHI nodes");

  // Move instructions to new block.
  BasicBlock *Old = IP.getBlock();
  New->getInstList().splice(New->begin(), Old->getInstList(), IP.getPoint(),
                            Old->end());

  if (CreateBranch)
    BranchInst::Create(New, Old);
}

void llvm::spliceBB(IRBuilder<> &Builder, BasicBlock *New, bool CreateBranch) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *Old = Builder.GetInsertBlock();

  spliceBB(Builder.saveIP(), New, CreateBranch);
  if (CreateBranch)
    Builder.SetInsertPoint(Old->getTerminator());
  else
    Builder.SetInsertPoint(Old);

  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
}

BasicBlock *llvm::splitBB(IRBuilderBase::InsertPoint IP, bool CreateBranch,
                          llvm::Twine Name) {
  BasicBlock *Old = IP.getBlock();
  BasicBlock *New = BasicBlock::Create(
      Old->getContext(), Name.isTriviallyEmpty() ? Old->getName() : Name,
      Old->getParent(), Old->getNextNode());
  spliceBB(IP, New, CreateBranch);
  New->replaceSuccessorsPhiUsesWith(Old, New);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilderBase &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBB(IRBuilder<> &Builder, bool CreateBranch,
                          llvm::Twine Name) {
  DebugLoc DebugLoc = Builder.getCurrentDebugLocation();
  BasicBlock *New = splitBB(Builder.saveIP(), CreateBranch, Name);
  if (CreateBranch)
    Builder.SetInsertPoint(Builder.GetInsertBlock()->getTerminator());
  else
    Builder.SetInsertPoint(Builder.GetInsertBlock());
  // SetInsertPoint also updates the Builder's debug location, but we want to
  // keep the one the Builder was configured to use.
  Builder.SetCurrentDebugLocation(DebugLoc);
  return New;
}

BasicBlock *llvm::splitBBWithSuffix(IRBuilderBase &Builder, bool CreateBranch,
                                    llvm::Twine Suffix) {
  BasicBlock *Old = Builder.GetInsertBlock();
  return splitBB(Builder, CreateBranch, Old->getName() + Suffix);
}

void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
  LLVMContext &Ctx = Fn.getContext();

  // Get the function's current attributes.
  auto Attrs = Fn.getAttributes();
  auto FnAttrs = Attrs.getFnAttrs();
  auto RetAttrs = Attrs.getRetAttrs();
  SmallVector<AttributeSet, 4> ArgAttrs;
  for (size_t ArgNo = 0; ArgNo < Fn.arg_size(); ++ArgNo)
    ArgAttrs.emplace_back(Attrs.getParamAttrs(ArgNo));

#define OMP_ATTRS_SET(VarName, AttrSet) AttributeSet VarName = AttrSet;
#include "llvm/Frontend/OpenMP/OMPKinds.def"

  // Add attributes to the function declaration.
  switch (FnID) {
#define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
  case Enum:                                                                   \
    FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
    RetAttrs = RetAttrs.addAttributes(Ctx, RetAttrSet);                        \
    for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
      ArgAttrs[ArgNo] =                                                        \
          ArgAttrs[ArgNo].addAttributes(Ctx, ArgAttrSets[ArgNo]);              \
    Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    // Attributes are optional.
    break;
  }
}

FunctionCallee
OpenMPIRBuilder::getOrCreateRuntimeFunction(Module &M, RuntimeFunction FnID) {
  FunctionType *FnTy = nullptr;
  Function *Fn = nullptr;

  // Try to find the declaration in the module first.
  switch (FnID) {
#define OMP_RTL(Enum, Str, IsVarArg, ReturnType, ...)                          \
  case Enum:                                                                   \
    FnTy = FunctionType::get(ReturnType, ArrayRef<Type *>{__VA_ARGS__},        \
                             IsVarArg);                                        \
    Fn = M.getFunction(Str);                                                   \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  }

  if (!Fn) {
    // Create a new declaration if we need one.
    switch (FnID) {
#define OMP_RTL(Enum, Str, ...)                                                \
  case Enum:                                                                   \
    Fn = Function::Create(FnTy, GlobalValue::ExternalLinkage, Str, M);         \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
    }

    // Add information if the runtime function takes a callback function.
    if (FnID == OMPRTL___kmpc_fork_call || FnID == OMPRTL___kmpc_fork_teams) {
      if (!Fn->hasMetadata(LLVMContext::MD_callback)) {
        LLVMContext &Ctx = Fn->getContext();
        MDBuilder MDB(Ctx);
        // Annotate the callback behavior of the runtime function:
        // - The callback callee is argument number 2 (microtask).
        // - The first two arguments of the callback callee are unknown (-1).
        // - All variadic arguments to the runtime function are passed to the
        //   callback callee.
        Fn->addMetadata(
            LLVMContext::MD_callback,
            *MDNode::get(Ctx, {MDB.createCallbackEncoding(
                                  2, {-1, -1}, /* VarArgsArePassed */ true)}));
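        // A sketch of the resulting IR (metadata IDs are illustrative):
        //   declare !callback !0 void @__kmpc_fork_call(ptr, i32, ptr, ...)
        //   !0 = !{!1}
        //   !1 = !{i64 2, i64 -1, i64 -1, i1 true}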
      }
    }

    LLVM_DEBUG(dbgs() << "Created OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
    addAttributes(FnID, *Fn);

  } else {
    LLVM_DEBUG(dbgs() << "Found OpenMP runtime function " << Fn->getName()
                      << " with type " << *Fn->getFunctionType() << "\n");
  }

  assert(Fn && "Failed to create OpenMP runtime function");

  // Cast the function to the expected type if necessary.
  Constant *C = ConstantExpr::getBitCast(Fn, FnTy->getPointerTo());
  return {FnTy, C};
}

Function *OpenMPIRBuilder::getOrCreateRuntimeFunctionPtr(RuntimeFunction FnID) {
  FunctionCallee RTLFn = getOrCreateRuntimeFunction(M, FnID);
  auto *Fn = dyn_cast<llvm::Function>(RTLFn.getCallee());
  assert(Fn && "Failed to create OpenMP runtime function pointer");
  return Fn;
}

void OpenMPIRBuilder::initialize() { initializeTypes(M); }

void OpenMPIRBuilder::finalize(Function *Fn) {
  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  SmallVector<OutlineInfo, 16> DeferredOutlines;
  for (OutlineInfo &OI : OutlineInfos) {
    // Skip functions that have not been finalized yet; this may happen with
    // nested function generation.
    if (Fn && OI.getFunction() != Fn) {
      DeferredOutlines.push_back(OI);
      continue;
    }

    ParallelRegionBlockSet.clear();
    Blocks.clear();
    OI.collectBlocks(ParallelRegionBlockSet, Blocks);

    Function *OuterFn = OI.getFunction();
    CodeExtractorAnalysisCache CEAC(*OuterFn);
    CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                            /* AggregateArgs */ true,
                            /* BlockFrequencyInfo */ nullptr,
                            /* BranchProbabilityInfo */ nullptr,
                            /* AssumptionCache */ nullptr,
                            /* AllowVarArgs */ true,
                            /* AllowAlloca */ true,
                            /* AllocaBlock */ OI.OuterAllocaBB,
                            /* Suffix */ ".omp_par");

    LLVM_DEBUG(dbgs() << "Before outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << "Entry " << OI.EntryBB->getName()
                      << " Exit: " << OI.ExitBB->getName() << "\n");
    assert(Extractor.isEligible() &&
           "Expected OpenMP outlining to be possible!");

    for (auto *V : OI.ExcludeArgsFromAggregate)
      Extractor.excludeArgFromAggregate(V);

    Function *OutlinedFn = Extractor.extractCodeRegion(CEAC);

    LLVM_DEBUG(dbgs() << "After outlining: " << *OuterFn << "\n");
    LLVM_DEBUG(dbgs() << " Outlined function: " << *OutlinedFn << "\n");
    assert(OutlinedFn->getReturnType()->isVoidTy() &&
           "OpenMP outlined functions should not return a value!");

    // For compatibility with the clang CG we move the outlined function after
    // the one with the parallel region.
    OutlinedFn->removeFromParent();
    M.getFunctionList().insertAfter(OuterFn->getIterator(), OutlinedFn);

    // Remove the artificial entry introduced by the extractor right away; we
    // made our own entry block after all.
    {
      BasicBlock &ArtificialEntry = OutlinedFn->getEntryBlock();
      assert(ArtificialEntry.getUniqueSuccessor() == OI.EntryBB);
      assert(OI.EntryBB->getUniquePredecessor() == &ArtificialEntry);
      // Move instructions from the to-be-deleted ArtificialEntry to the entry
      // basic block of the parallel region. CodeExtractor generates
      // instructions to unwrap the aggregate argument and may sink
      // allocas/bitcasts for values that are solely used in the outlined
      // region and do not escape.
      assert(!ArtificialEntry.empty() &&
             "Expected instructions to add in the outlined region entry");
      for (BasicBlock::reverse_iterator It = ArtificialEntry.rbegin(),
                                        End = ArtificialEntry.rend();
           It != End;) {
        Instruction &I = *It;
        It++;

        if (I.isTerminator())
          continue;

        I.moveBefore(*OI.EntryBB, OI.EntryBB->getFirstInsertionPt());
      }

      OI.EntryBB->moveBefore(&ArtificialEntry);
      ArtificialEntry.eraseFromParent();
    }
    assert(&OutlinedFn->getEntryBlock() == OI.EntryBB);
    assert(OutlinedFn && OutlinedFn->getNumUses() == 1);

    // Run a user callback, e.g. to add attributes.
    if (OI.PostOutlineCB)
      OI.PostOutlineCB(*OutlinedFn);
  }

  // Remove work items that have been completed.
  OutlineInfos = std::move(DeferredOutlines);
}

OpenMPIRBuilder::~OpenMPIRBuilder() {
  assert(OutlineInfos.empty() && "There must be no outstanding outlinings");
}

GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
  IntegerType *I32Ty = Type::getInt32Ty(M.getContext());
  auto *GV =
      new GlobalVariable(M, I32Ty,
                         /* isConstant = */ true, GlobalValue::WeakODRLinkage,
                         ConstantInt::get(I32Ty, Value), Name);
  GV->setVisibility(GlobalValue::HiddenVisibility);

  return GV;
}

Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                            uint32_t SrcLocStrSize,
                                            IdentFlag LocFlags,
                                            unsigned Reserve2Flags) {
  // Enable "C-mode".
  LocFlags |= OMP_IDENT_FLAG_KMPC;

  Constant *&Ident =
      IdentMap[{SrcLocStr, uint64_t(LocFlags) << 31 | Reserve2Flags}];
  if (!Ident) {
    Constant *I32Null = ConstantInt::getNullValue(Int32);
    Constant *IdentData[] = {I32Null,
                             ConstantInt::get(Int32, uint32_t(LocFlags)),
                             ConstantInt::get(Int32, Reserve2Flags),
                             ConstantInt::get(Int32, SrcLocStrSize), SrcLocStr};
    Constant *Initializer =
        ConstantStruct::get(OpenMPIRBuilder::Ident, IdentData);

    // Look for an existing encoding of the location + flags; this is not
    // strictly needed but minimizes the difference to the existing solution
    // while we transition.
    for (GlobalVariable &GV : M.getGlobalList())
      if (GV.getValueType() == OpenMPIRBuilder::Ident && GV.hasInitializer())
        if (GV.getInitializer() == Initializer)
          Ident = &GV;

    if (!Ident) {
      auto *GV = new GlobalVariable(
          M, OpenMPIRBuilder::Ident,
          /* isConstant = */ true, GlobalValue::PrivateLinkage, Initializer, "",
          nullptr, GlobalValue::NotThreadLocal,
          M.getDataLayout().getDefaultGlobalsAddressSpace());
      GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global);
      GV->setAlignment(Align(8));
      Ident = GV;
    }
  }

  return ConstantExpr::getPointerBitCastOrAddrSpaceCast(Ident, IdentPtr);
}
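
// A sketch of an emitted ident_t global, assuming only the KMPC flag is set
// and a 23-byte source-location string (field values are illustrative):
//   @0 = private unnamed_addr constant %struct.ident_t
//        { i32 0, i32 2, i32 0, i32 23, ptr @.str }, align 8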

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef LocStr,
                                                uint32_t &SrcLocStrSize) {
  SrcLocStrSize = LocStr.size();
  Constant *&SrcLocStr = SrcLocStrMap[LocStr];
  if (!SrcLocStr) {
    Constant *Initializer =
        ConstantDataArray::getString(M.getContext(), LocStr);

    // Look for an existing encoding of the location; this is not strictly
    // needed but minimizes the difference to the existing solution while we
    // transition.
    for (GlobalVariable &GV : M.getGlobalList())
      if (GV.isConstant() && GV.hasInitializer() &&
          GV.getInitializer() == Initializer)
        return SrcLocStr = ConstantExpr::getPointerCast(&GV, Int8Ptr);

    SrcLocStr = Builder.CreateGlobalStringPtr(LocStr, /* Name */ "",
                                              /* AddressSpace */ 0, &M);
  }
  return SrcLocStr;
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(StringRef FunctionName,
                                                StringRef FileName,
                                                unsigned Line, unsigned Column,
                                                uint32_t &SrcLocStrSize) {
  SmallString<128> Buffer;
  Buffer.push_back(';');
  Buffer.append(FileName);
  Buffer.push_back(';');
  Buffer.append(FunctionName);
  Buffer.push_back(';');
  Buffer.append(std::to_string(Line));
  Buffer.push_back(';');
  Buffer.append(std::to_string(Column));
  Buffer.push_back(';');
  Buffer.push_back(';');
  return getOrCreateSrcLocStr(Buffer.str(), SrcLocStrSize);
}
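
// For example, FunctionName "foo", FileName "test.c", Line 4, and Column 7
// produce the location string ";test.c;foo;4;7;;".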

Constant *
OpenMPIRBuilder::getOrCreateDefaultSrcLocStr(uint32_t &SrcLocStrSize) {
  StringRef UnknownLoc = ";unknown;unknown;0;0;;";
  return getOrCreateSrcLocStr(UnknownLoc, SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(DebugLoc DL,
                                                uint32_t &SrcLocStrSize,
                                                Function *F) {
  DILocation *DIL = DL.get();
  if (!DIL)
    return getOrCreateDefaultSrcLocStr(SrcLocStrSize);
  StringRef FileName = M.getName();
  if (DIFile *DIF = DIL->getFile())
    if (Optional<StringRef> Source = DIF->getSource())
      FileName = *Source;
  StringRef Function = DIL->getScope()->getSubprogram()->getName();
  if (Function.empty() && F)
    Function = F->getName();
  return getOrCreateSrcLocStr(Function, FileName, DIL->getLine(),
                              DIL->getColumn(), SrcLocStrSize);
}

Constant *OpenMPIRBuilder::getOrCreateSrcLocStr(const LocationDescription &Loc,
                                                uint32_t &SrcLocStrSize) {
  return getOrCreateSrcLocStr(Loc.DL, SrcLocStrSize,
                              Loc.IP.getBlock()->getParent());
}

Value *OpenMPIRBuilder::getOrCreateThreadID(Value *Ident) {
  return Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num), Ident,
      "omp_global_thread_num");
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createBarrier(const LocationDescription &Loc, Directive DK,
                               bool ForceSimpleCall, bool CheckCancelFlag) {
  if (!updateToLocation(Loc))
    return Loc.IP;
  return emitBarrierImpl(Loc, DK, ForceSimpleCall, CheckCancelFlag);
}

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::emitBarrierImpl(const LocationDescription &Loc, Directive Kind,
                                 bool ForceSimpleCall, bool CheckCancelFlag) {
  // Build call __kmpc_cancel_barrier(loc, thread_id) or
  //            __kmpc_barrier(loc, thread_id);

  IdentFlag BarrierLocFlags;
  switch (Kind) {
  case OMPD_for:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_FOR;
    break;
  case OMPD_sections:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SECTIONS;
    break;
  case OMPD_single:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL_SINGLE;
    break;
  case OMPD_barrier:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_EXPL;
    break;
  default:
    BarrierLocFlags = OMP_IDENT_FLAG_BARRIER_IMPL;
    break;
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {
      getOrCreateIdent(SrcLocStr, SrcLocStrSize, BarrierLocFlags),
      getOrCreateThreadID(getOrCreateIdent(SrcLocStr, SrcLocStrSize))};

  // If we are in a cancellable parallel region, barriers are cancellation
  // points.
  // TODO: Check why we would force simple calls or to ignore the cancel flag.
  bool UseCancelBarrier =
      !ForceSimpleCall && isLastFinalizationInfoCancellable(OMPD_parallel);

  Value *Result =
      Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                             UseCancelBarrier ? OMPRTL___kmpc_cancel_barrier
                                              : OMPRTL___kmpc_barrier),
                         Args);

  if (UseCancelBarrier && CheckCancelFlag)
    emitCancelationCheckImpl(Result, OMPD_parallel);

  return Builder.saveIP();
}
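
// A sketch of the IR emitted for a plain barrier:
//   %gtid = call i32 @__kmpc_global_thread_num(ptr @loc)
//   call void @__kmpc_barrier(ptr @loc.barrier, i32 %gtid)
// In a cancellable parallel region, @__kmpc_cancel_barrier is called instead
// and its i32 result feeds the cancellation check below.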

OpenMPIRBuilder::InsertPointTy
OpenMPIRBuilder::createCancel(const LocationDescription &Loc,
                              Value *IfCondition,
                              omp::Directive CanceledDirective) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  // LLVM utilities prefer blocks with terminators.
  auto *UI = Builder.CreateUnreachable();

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);
  Builder.SetInsertPoint(ThenTI);

  Value *CancelKind = nullptr;
  switch (CanceledDirective) {
#define OMP_CANCEL_KIND(Enum, Str, DirectiveEnum, Value)                       \
  case DirectiveEnum:                                                          \
    CancelKind = Builder.getInt32(Value);                                      \
    break;
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  default:
    llvm_unreachable("Unknown cancel kind!");
  }

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), CancelKind};
  Value *Result = Builder.CreateCall(
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_cancel), Args);
  auto ExitCB = [this, CanceledDirective, Loc](InsertPointTy IP) {
    if (CanceledDirective == OMPD_parallel) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
                    omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
                    /* CheckCancelFlag */ false);
    }
  };

  // The actual cancel logic is shared with others, e.g., cancel_barriers.
  emitCancelationCheckImpl(Result, CanceledDirective, ExitCB);

  // Update the insertion point and remove the terminator we introduced.
  Builder.SetInsertPoint(UI->getParent());
  UI->eraseFromParent();

  return Builder.saveIP();
}

void OpenMPIRBuilder::emitOffloadingEntry(Constant *Addr, StringRef Name,
                                          uint64_t Size, int32_t Flags,
                                          StringRef SectionName) {
  Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
  Type *Int32Ty = Type::getInt32Ty(M.getContext());
  Type *SizeTy = M.getDataLayout().getIntPtrType(M.getContext());

  Constant *AddrName = ConstantDataArray::getString(M.getContext(), Name);

  // Create the constant string used to look up the symbol in the device.
  auto *Str =
      new llvm::GlobalVariable(M, AddrName->getType(), /*isConstant=*/true,
                               llvm::GlobalValue::InternalLinkage, AddrName,
                               ".omp_offloading.entry_name");
  Str->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);

  // Construct the offloading entry.
  Constant *EntryData[] = {
      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Addr, Int8PtrTy),
      ConstantExpr::getPointerBitCastOrAddrSpaceCast(Str, Int8PtrTy),
      ConstantInt::get(SizeTy, Size),
      ConstantInt::get(Int32Ty, Flags),
      ConstantInt::get(Int32Ty, 0),
  };
  Constant *EntryInitializer =
      ConstantStruct::get(OpenMPIRBuilder::OffloadEntry, EntryData);

  auto *Entry = new GlobalVariable(
      M, OpenMPIRBuilder::OffloadEntry,
      /* isConstant = */ true, GlobalValue::WeakAnyLinkage, EntryInitializer,
      ".omp_offloading.entry." + Name, nullptr, GlobalValue::NotThreadLocal,
      M.getDataLayout().getDefaultGlobalsAddressSpace());

  // The entry has to be created in the section the linker expects it to be.
  Entry->setSection(SectionName);
  Entry->setAlignment(Align(1));
}
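
// The emitted entry is a constant OpenMPIRBuilder::OffloadEntry struct of the
// shape { ptr <addr>, ptr <name>, <intptr> <size>, i32 <flags>, i32 0 },
// where the size field uses the target's pointer-sized integer. It is placed
// in the section the linker scans to build the offloading entry table.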

OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetKernel(
    const LocationDescription &Loc, Value *&Return, Value *Ident,
    Value *DeviceID, Value *NumTeams, Value *NumThreads, Value *HostPtr,
    ArrayRef<Value *> KernelArgs, ArrayRef<Value *> NoWaitArgs) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  auto *KernelArgsPtr =
      Builder.CreateAlloca(OpenMPIRBuilder::KernelArgs, nullptr, "kernel_args");
  for (unsigned I = 0, Size = KernelArgs.size(); I != Size; ++I) {
    llvm::Value *Arg =
        Builder.CreateStructGEP(OpenMPIRBuilder::KernelArgs, KernelArgsPtr, I);
    Builder.CreateAlignedStore(
        KernelArgs[I], Arg,
        M.getDataLayout().getPrefTypeAlign(KernelArgs[I]->getType()));
  }

  bool HasNoWait = !NoWaitArgs.empty();
  SmallVector<Value *> OffloadingArgs{Ident,      DeviceID, NumTeams,
                                      NumThreads, HostPtr,  KernelArgsPtr};
  if (HasNoWait)
    OffloadingArgs.append(NoWaitArgs.begin(), NoWaitArgs.end());

  Return = Builder.CreateCall(
      HasNoWait
          ? getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel_nowait)
          : getOrCreateRuntimeFunction(M, OMPRTL___tgt_target_kernel),
      OffloadingArgs);

  return Builder.saveIP();
}

void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
                                               omp::Directive CanceledDirective,
                                               FinalizeCallbackTy ExitCB) {
  assert(isLastFinalizationInfoCancellable(CanceledDirective) &&
         "Unexpected cancellation!");

  // For a cancel barrier we create two new blocks.
  BasicBlock *BB = Builder.GetInsertBlock();
  BasicBlock *NonCancellationBlock;
  if (Builder.GetInsertPoint() == BB->end()) {
    // TODO: This branch will not be needed once we have moved to the
    // OpenMPIRBuilder codegen completely.
    NonCancellationBlock = BasicBlock::Create(
        BB->getContext(), BB->getName() + ".cont", BB->getParent());
  } else {
    NonCancellationBlock = SplitBlock(BB, &*Builder.GetInsertPoint());
    BB->getTerminator()->eraseFromParent();
    Builder.SetInsertPoint(BB);
  }
  BasicBlock *CancellationBlock = BasicBlock::Create(
      BB->getContext(), BB->getName() + ".cncl", BB->getParent());

  // Jump to them based on the return value.
  Value *Cmp = Builder.CreateIsNull(CancelFlag);
  Builder.CreateCondBr(Cmp, NonCancellationBlock, CancellationBlock,
                       /* TODO weight */ nullptr, nullptr);

  // From the cancellation block we finalize all variables and go to the
  // post finalization block that is known to the FiniCB callback.
  Builder.SetInsertPoint(CancellationBlock);
  if (ExitCB)
    ExitCB(Builder.saveIP());
  auto &FI = FinalizationStack.back();
  FI.FiniCB(Builder.saveIP());

  // The continuation block is where code generation continues.
  Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}
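
// A sketch of the emitted check when a fresh continuation block is created
// for a block named %bb:
//   %cmp = icmp eq i32 %cancel_flag, 0
//   br i1 %cmp, label %bb.cont, label %bb.cncl
// %bb.cncl runs the finalization callbacks before control leaves the region.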

IRBuilder<>::InsertPoint OpenMPIRBuilder::createParallel(
    const LocationDescription &Loc, InsertPointTy OuterAllocaIP,
    BodyGenCallbackTy BodyGenCB, PrivatizeCallbackTy PrivCB,
    FinalizeCallbackTy FiniCB, Value *IfCondition, Value *NumThreads,
    omp::ProcBindKind ProcBind, bool IsCancellable) {
  assert(!isConflictIP(Loc.IP, OuterAllocaIP) && "IPs must not be ambiguous");

  if (!updateToLocation(Loc))
    return Loc.IP;

  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *ThreadID = getOrCreateThreadID(Ident);

  if (NumThreads) {
    // Build call __kmpc_push_num_threads(&Ident, global_tid, num_threads)
    Value *Args[] = {
        Ident, ThreadID,
        Builder.CreateIntCast(NumThreads, Int32, /*isSigned*/ false)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_num_threads), Args);
  }

  if (ProcBind != OMP_PROC_BIND_default) {
    // Build call __kmpc_push_proc_bind(&Ident, global_tid, proc_bind)
    Value *Args[] = {
        Ident, ThreadID,
        ConstantInt::get(Int32, unsigned(ProcBind), /*isSigned=*/true)};
    Builder.CreateCall(
        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_push_proc_bind), Args);
  }

  BasicBlock *InsertBB = Builder.GetInsertBlock();
  Function *OuterFn = InsertBB->getParent();

  // Save the outer alloca block because the insertion iterator may get
  // invalidated and we still need this later.
  BasicBlock *OuterAllocaBlock = OuterAllocaIP.getBlock();

  // Vector to remember instructions we used only during the modeling but which
  // we want to delete at the end.
  SmallVector<Instruction *, 4> ToBeDeleted;

  // Change the location to the outer alloca insertion point to create and
  // initialize the allocas we pass into the parallel region.
  Builder.restoreIP(OuterAllocaIP);
  AllocaInst *TIDAddr = Builder.CreateAlloca(Int32, nullptr, "tid.addr");
  AllocaInst *ZeroAddr = Builder.CreateAlloca(Int32, nullptr, "zero.addr");

  // If there is an if condition we actually use TIDAddr and ZeroAddr in the
  // program, otherwise we only need them for modeling purposes to get the
  // associated arguments in the outlined function. In the former case,
  // initialize the allocas properly; in the latter case, delete them later.
  if (IfCondition) {
    Builder.CreateStore(Constant::getNullValue(Int32), TIDAddr);
    Builder.CreateStore(Constant::getNullValue(Int32), ZeroAddr);
  } else {
    ToBeDeleted.push_back(TIDAddr);
    ToBeDeleted.push_back(ZeroAddr);
  }

  // Create an artificial insertion point that will also ensure the blocks we
  // are about to split are not degenerated.
  auto *UI = new UnreachableInst(Builder.getContext(), InsertBB);

  Instruction *ThenTI = UI, *ElseTI = nullptr;
  if (IfCondition)
    SplitBlockAndInsertIfThenElse(IfCondition, UI, &ThenTI, &ElseTI);

  BasicBlock *ThenBB = ThenTI->getParent();
  BasicBlock *PRegEntryBB = ThenBB->splitBasicBlock(ThenTI, "omp.par.entry");
  BasicBlock *PRegBodyBB =
      PRegEntryBB->splitBasicBlock(ThenTI, "omp.par.region");
  BasicBlock *PRegPreFiniBB =
      PRegBodyBB->splitBasicBlock(ThenTI, "omp.par.pre_finalize");
  BasicBlock *PRegExitBB =
      PRegPreFiniBB->splitBasicBlock(ThenTI, "omp.par.exit");

  auto FiniCBWrapper = [&](InsertPointTy IP) {
    // Hide "open-ended" blocks from the given FiniCB by setting the right jump
    // target to the region exit block.
    if (IP.getBlock()->end() == IP.getPoint()) {
      IRBuilder<>::InsertPointGuard IPG(Builder);
      Builder.restoreIP(IP);
      Instruction *I = Builder.CreateBr(PRegExitBB);
      IP = InsertPointTy(I->getParent(), I->getIterator());
    }
    assert(IP.getBlock()->getTerminator()->getNumSuccessors() == 1 &&
           IP.getBlock()->getTerminator()->getSuccessor(0) == PRegExitBB &&
           "Unexpected insertion point for finalization call!");
    return FiniCB(IP);
  };

  FinalizationStack.push_back({FiniCBWrapper, OMPD_parallel, IsCancellable});

  // Generate the privatization allocas in the block that will become the entry
  // of the outlined function.
  Builder.SetInsertPoint(PRegEntryBB->getTerminator());
  InsertPointTy InnerAllocaIP = Builder.saveIP();

  AllocaInst *PrivTIDAddr =
      Builder.CreateAlloca(Int32, nullptr, "tid.addr.local");
  Instruction *PrivTID = Builder.CreateLoad(Int32, PrivTIDAddr, "tid");

  // Add some fake uses for OpenMP provided arguments.
  ToBeDeleted.push_back(Builder.CreateLoad(Int32, TIDAddr, "tid.addr.use"));
  Instruction *ZeroAddrUse =
      Builder.CreateLoad(Int32, ZeroAddr, "zero.addr.use");
  ToBeDeleted.push_back(ZeroAddrUse);

  // ThenBB
  //   |
  //   V
  // PRegionEntryBB    <- Privatization allocas are placed here.
  //   |
  //   V
  // PRegionBodyBB     <- BodyGen is invoked here.
  //   |
  //   V
  // PRegPreFiniBB     <- The block we will start finalization from.
  //   |
  //   V
  // PRegionExitBB     <- A common exit to simplify block collection.
  //

  LLVM_DEBUG(dbgs() << "Before body codegen: " << *OuterFn << "\n");

  // Let the caller create the body.
  assert(BodyGenCB && "Expected body generation callback!");
  InsertPointTy CodeGenIP(PRegBodyBB, PRegBodyBB->begin());
  BodyGenCB(InnerAllocaIP, CodeGenIP);

  LLVM_DEBUG(dbgs() << "After body codegen: " << *OuterFn << "\n");

  FunctionCallee RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
  if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
    if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
      llvm::LLVMContext &Ctx = F->getContext();
      MDBuilder MDB(Ctx);
      // Annotate the callback behavior of the __kmpc_fork_call:
      // - The callback callee is argument number 2 (microtask).
      // - The first two arguments of the callback callee are unknown (-1).
      // - All variadic arguments to the __kmpc_fork_call are passed to the
      //   callback callee.
      F->addMetadata(
          llvm::LLVMContext::MD_callback,
          *llvm::MDNode::get(
              Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
                                               /* VarArgsArePassed */ true)}));
    }
  }

  OutlineInfo OI;
  OI.PostOutlineCB = [=](Function &OutlinedFn) {
    // Add some known attributes.
    OutlinedFn.addParamAttr(0, Attribute::NoAlias);
    OutlinedFn.addParamAttr(1, Attribute::NoAlias);
    OutlinedFn.addFnAttr(Attribute::NoUnwind);
    OutlinedFn.addFnAttr(Attribute::NoRecurse);

    assert(OutlinedFn.arg_size() >= 2 &&
           "Expected at least tid and bounded tid as arguments");
    unsigned NumCapturedVars =
        OutlinedFn.arg_size() - /* tid & bounded tid */ 2;

    CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
    CI->getParent()->setName("omp_parallel");
    Builder.SetInsertPoint(CI);

    // Build call __kmpc_fork_call(Ident, n, microtask, var1, .., varn);
    Value *ForkCallArgs[] = {
        Ident, Builder.getInt32(NumCapturedVars),
        Builder.CreateBitCast(&OutlinedFn, ParallelTaskPtr)};

    SmallVector<Value *, 16> RealArgs;
    RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
    RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());

    Builder.CreateCall(RTLFn, RealArgs);

    LLVM_DEBUG(dbgs() << "With fork_call placed: "
                      << *Builder.GetInsertBlock()->getParent() << "\n");

    InsertPointTy ExitIP(PRegExitBB, PRegExitBB->end());

    // Initialize the local TID stack location with the argument value.
    Builder.SetInsertPoint(PrivTID);
    Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
    Builder.CreateStore(Builder.CreateLoad(Int32, OutlinedAI), PrivTIDAddr);

    // If no "if" clause was present we do not need the call created during
    // outlining, otherwise we reuse it in the serialized parallel region.
    if (!ElseTI) {
      CI->eraseFromParent();
    } else {

      // If an "if" clause was present we are now generating the serialized
      // version into the "else" branch.
      Builder.SetInsertPoint(ElseTI);

      // Build calls __kmpc_serialized_parallel(&Ident, GTid);
      Value *SerializedParallelCallArgs[] = {Ident, ThreadID};
      Builder.CreateCall(
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_serialized_parallel),
          SerializedParallelCallArgs);

      // OutlinedFn(&gtid, &zero, CapturedStruct);
      CI->removeFromParent();
      Builder.Insert(CI);

      // __kmpc_end_serialized_parallel(&Ident, GTid);
      Value *EndArgs[] = {Ident, ThreadID};
      Builder.CreateCall(
          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_serialized_parallel),
          EndArgs);

      LLVM_DEBUG(dbgs() << "With serialized parallel region: "
                        << *Builder.GetInsertBlock()->getParent() << "\n");
    }

    for (Instruction *I : ToBeDeleted)
      I->eraseFromParent();
  };

  // Adjust the finalization stack, verify the adjustment, and call the
  // finalize function a last time to finalize values between the pre-fini
  // block and the exit block if we left the parallel region "the normal way".
  auto FiniInfo = FinalizationStack.pop_back_val();
  (void)FiniInfo;
  assert(FiniInfo.DK == OMPD_parallel &&
         "Unexpected finalization stack state!");

  Instruction *PRegPreFiniTI = PRegPreFiniBB->getTerminator();

  InsertPointTy PreFiniIP(PRegPreFiniBB, PRegPreFiniTI->getIterator());
  FiniCB(PreFiniIP);

  OI.OuterAllocaBB = OuterAllocaBlock;
  OI.EntryBB = PRegEntryBB;
  OI.ExitBB = PRegExitBB;

  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
  SmallVector<BasicBlock *, 32> Blocks;
  OI.collectBlocks(ParallelRegionBlockSet, Blocks);

  // Ensure a single exit node for the outlined region by creating one.
  // We might have multiple incoming edges to the exit now due to finalizations,
  // e.g., cancel calls that cause the control flow to leave the region.
  BasicBlock *PRegOutlinedExitBB = PRegExitBB;
  PRegExitBB = SplitBlock(PRegExitBB, &*PRegExitBB->getFirstInsertionPt());
  PRegOutlinedExitBB->setName("omp.par.outlined.exit");
  Blocks.push_back(PRegOutlinedExitBB);

  CodeExtractorAnalysisCache CEAC(*OuterFn);
  CodeExtractor Extractor(Blocks, /* DominatorTree */ nullptr,
                          /* AggregateArgs */ false,
                          /* BlockFrequencyInfo */ nullptr,
                          /* BranchProbabilityInfo */ nullptr,
                          /* AssumptionCache */ nullptr,
                          /* AllowVarArgs */ true,
                          /* AllowAlloca */ true,
                          /* AllocationBlock */ OuterAllocaBlock,
                          /* Suffix */ ".omp_par");

  // Find inputs to, outputs from the code region.
  BasicBlock *CommonExit = nullptr;
  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);

  LLVM_DEBUG(dbgs() << "Before privatization: " << *OuterFn << "\n");

  FunctionCallee TIDRTLFn =
      getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_global_thread_num);

  auto PrivHelper = [&](Value &V) {
    if (&V == TIDAddr || &V == ZeroAddr) {
      OI.ExcludeArgsFromAggregate.push_back(&V);
      return;
    }

    SetVector<Use *> Uses;
    for (Use &U : V.uses())
      if (auto *UserI = dyn_cast<Instruction>(U.getUser()))
        if (ParallelRegionBlockSet.count(UserI->getParent()))
          Uses.insert(&U);

    // __kmpc_fork_call expects extra arguments as pointers. If the input
    // already has a pointer type, everything is fine. Otherwise, store the
    // value onto stack and load it back inside the to-be-outlined region. This
    // will ensure only the pointer will be passed to the function.
    // FIXME: if there are more than 15 trailing arguments, they must be
    // additionally packed in a struct.
    Value *Inner = &V;
    if (!V.getType()->isPointerTy()) {
      IRBuilder<>::InsertPointGuard Guard(Builder);
      LLVM_DEBUG(llvm::dbgs() << "Forwarding input as pointer: " << V << "\n");

      Builder.restoreIP(OuterAllocaIP);
      Value *Ptr =
          Builder.CreateAlloca(V.getType(), nullptr, V.getName() + ".reloaded");

      // Store to stack at end of the block that currently branches to the
      // entry block of the to-be-outlined region.
      Builder.SetInsertPoint(InsertBB,
                             InsertBB->getTerminator()->getIterator());
      Builder.CreateStore(&V, Ptr);

      // Load back next to allocations in the to-be-outlined region.
      Builder.restoreIP(InnerAllocaIP);
      Inner = Builder.CreateLoad(V.getType(), Ptr);
    }

    Value *ReplacementValue = nullptr;
    CallInst *CI = dyn_cast<CallInst>(&V);
    if (CI && CI->getCalledFunction() == TIDRTLFn.getCallee()) {
      ReplacementValue = PrivTID;
    } else {
      Builder.restoreIP(
          PrivCB(InnerAllocaIP, Builder.saveIP(), V, *Inner, ReplacementValue));
      assert(ReplacementValue &&
             "Expected copy/create callback to set replacement value!");
      if (ReplacementValue == &V)
        return;
    }

    for (Use *UPtr : Uses)
      UPtr->set(ReplacementValue);
  };

  // Reset the inner alloca insertion as it will be used for loading the values
  // wrapped into pointers before passing them into the to-be-outlined region.
  // Configure it to insert immediately after the fake use of zero address so
  // that they are available in the generated body and so that the
  // OpenMP-related values (thread ID and zero address pointers) remain leading
  // in the argument list.
  InnerAllocaIP = IRBuilder<>::InsertPoint(
      ZeroAddrUse->getParent(), ZeroAddrUse->getNextNode()->getIterator());

  // Reset the outer alloca insertion point to the entry of the relevant block
  // in case it was invalidated.
  OuterAllocaIP = IRBuilder<>::InsertPoint(
      OuterAllocaBlock, OuterAllocaBlock->getFirstInsertionPt());

  for (Value *Input : Inputs) {
    LLVM_DEBUG(dbgs() << "Captured input: " << *Input << "\n");
    PrivHelper(*Input);
  }
  LLVM_DEBUG({
    for (Value *Output : Outputs)
      LLVM_DEBUG(dbgs() << "Captured output: " << *Output << "\n");
  });
  assert(Outputs.empty() &&
         "OpenMP outlining should not produce live-out values!");

  LLVM_DEBUG(dbgs() << "After privatization: " << *OuterFn << "\n");
  LLVM_DEBUG({
    for (auto *BB : Blocks)
      dbgs() << " PBR: " << BB->getName() << "\n";
  });

  // Register the outlined info.
  addOutlineInfo(std::move(OI));

  InsertPointTy AfterIP(UI->getParent(), UI->getParent()->end());
  UI->eraseFromParent();

  return AfterIP;
}
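
// Once OpenMPIRBuilder::finalize has outlined the region, the PostOutlineCB
// above replaces the extractor's call with, roughly (a sketch, without an
// "if" clause):
//   call void (ptr, i32, ptr, ...) @__kmpc_fork_call(
//       ptr @loc, i32 <num captured vars>, ptr @<fn>.omp_par, ...)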

void OpenMPIRBuilder::emitFlush(const LocationDescription &Loc) {
  // Build call void __kmpc_flush(ident_t *loc)
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Args[] = {getOrCreateIdent(SrcLocStr, SrcLocStrSize)};

  Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_flush), Args);
}

void OpenMPIRBuilder::createFlush(const LocationDescription &Loc) {
  if (!updateToLocation(Loc))
    return;
  emitFlush(Loc);
}

void OpenMPIRBuilder::emitTaskwaitImpl(const LocationDescription &Loc) {
  // Build call kmp_int32 __kmpc_omp_taskwait(ident_t *loc, kmp_int32
  // global_tid);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident)};

  // Ignore return result until untied tasks are supported.
  Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskwait),
                     Args);
}

void OpenMPIRBuilder::createTaskwait(const LocationDescription &Loc) {
  if (!updateToLocation(Loc))
    return;
  emitTaskwaitImpl(Loc);
}

void OpenMPIRBuilder::emitTaskyieldImpl(const LocationDescription &Loc) {
  // Build call __kmpc_omp_taskyield(loc, thread_id, 0);
  uint32_t SrcLocStrSize;
  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
  Constant *I32Null = ConstantInt::getNullValue(Int32);
  Value *Args[] = {Ident, getOrCreateThreadID(Ident), I32Null};

  Builder.CreateCall(getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_taskyield),
                     Args);
}

void OpenMPIRBuilder::createTaskyield(const LocationDescription &Loc) {
  if (!updateToLocation(Loc))
    return;
  emitTaskyieldImpl(Loc);
}
1287
1288 OpenMPIRBuilder::InsertPointTy
createTask(const LocationDescription & Loc,InsertPointTy AllocaIP,BodyGenCallbackTy BodyGenCB,bool Tied,Value * Final)1289 OpenMPIRBuilder::createTask(const LocationDescription &Loc,
1290 InsertPointTy AllocaIP, BodyGenCallbackTy BodyGenCB,
1291 bool Tied, Value *Final) {
1292 if (!updateToLocation(Loc))
1293 return InsertPointTy();
1294
1295 uint32_t SrcLocStrSize;
1296 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1297 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1298 // The current basic block is split into four basic blocks. After outlining,
1299 // they will be mapped as follows:
1300 // ```
1301 // def current_fn() {
1302 // current_basic_block:
1303 // br label %task.exit
1304 // task.exit:
1305 // ; instructions after task
1306 // }
1307 // def outlined_fn() {
1308 // task.alloca:
1309 // br label %task.body
1310 // task.body:
1311 // ret void
1312 // }
1313 // ```
1314 BasicBlock *TaskExitBB = splitBB(Builder, /*CreateBranch=*/true, "task.exit");
1315 BasicBlock *TaskBodyBB = splitBB(Builder, /*CreateBranch=*/true, "task.body");
1316 BasicBlock *TaskAllocaBB =
1317 splitBB(Builder, /*CreateBranch=*/true, "task.alloca");
1318
1319 OutlineInfo OI;
1320 OI.EntryBB = TaskAllocaBB;
1321 OI.OuterAllocaBB = AllocaIP.getBlock();
1322 OI.ExitBB = TaskExitBB;
1323 OI.PostOutlineCB = [this, Ident, Tied, Final](Function &OutlinedFn) {
1324 // The input IR here looks like the following-
1325 // ```
1326 // func @current_fn() {
1327 // outlined_fn(%args)
1328 // }
1329 // func @outlined_fn(%args) { ... }
1330 // ```
1331 //
1332 // This is changed to the following-
1333 //
1334 // ```
1335 // func @current_fn() {
1336 // runtime_call(..., wrapper_fn, ...)
1337 // }
1338 // func @wrapper_fn(..., %args) {
1339 // outlined_fn(%args)
1340 // }
1341 // func @outlined_fn(%args) { ... }
1342 // ```
1343
1344 // The stale call instruction will be replaced with a new call instruction
1345 // for runtime call with a wrapper function.
1346 assert(OutlinedFn.getNumUses() == 1 &&
1347 "there must be a single user for the outlined function");
1348 CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
1349
1350 // HasTaskData is true if any variables are captured in the outlined region,
1351 // false otherwise.
1352 bool HasTaskData = StaleCI->arg_size() > 0;
1353 Builder.SetInsertPoint(StaleCI);
1354
1355 // Gather the arguments for emitting the runtime call for
1356 // @__kmpc_omp_task_alloc
1357 Function *TaskAllocFn =
1358 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
1359
1360 // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
1361 // call.
1362 Value *ThreadID = getOrCreateThreadID(Ident);
1363
1364 // Argument - `flags`
1365 // Task is tied iff (Flags & 1) == 1.
1366 // Task is untied iff (Flags & 1) == 0.
1367 // Task is final iff (Flags & 2) == 2.
1368 // Task is not final iff (Flags & 2) == 0.
1369 // TODO: Handle the other flags.
1370 Value *Flags = Builder.getInt32(Tied);
1371 if (Final) {
1372 Value *FinalFlag =
1373 Builder.CreateSelect(Final, Builder.getInt32(2), Builder.getInt32(0));
1374 Flags = Builder.CreateOr(FinalFlag, Flags);
1375 }
1376
1377 // Argument - `sizeof_kmp_task_t` (TaskSize)
1378 // Tasksize refers to the size in bytes of kmp_task_t data structure
1379 // including private vars accessed in task.
1380 Value *TaskSize = Builder.getInt64(0);
1381 if (HasTaskData) {
1382 AllocaInst *ArgStructAlloca =
1383 dyn_cast<AllocaInst>(StaleCI->getArgOperand(0));
1384 assert(ArgStructAlloca &&
1385 "Unable to find the alloca instruction corresponding to arguments "
1386 "for extracted function");
1387 StructType *ArgStructType =
1388 dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
1389 assert(ArgStructType && "Unable to find struct type corresponding to "
1390 "arguments for extracted function");
1391 TaskSize =
1392 Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
1393 }
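    // For example (assumed captures, for illustration only): if the extracted
    // arguments were packed into a struct of type { i64, ptr } on a typical
    // 64-bit target, TaskSize would be the 16-byte store size of that struct.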
1394
1395 // TODO: Argument - sizeof_shareds
1396
1397 // Argument - task_entry (the wrapper function)
1398 // If the outlined function has some captured variables (i.e. HasTaskData is
1399 // true), then the wrapper function will have an additional argument (the
1400 // struct containing captured variables). Otherwise, no such argument will
1401 // be present.
1402 SmallVector<Type *> WrapperArgTys{Builder.getInt32Ty()};
1403 if (HasTaskData)
1404 WrapperArgTys.push_back(OutlinedFn.getArg(0)->getType());
1405 FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
1406 (Twine(OutlinedFn.getName()) + ".wrapper").str(),
1407 FunctionType::get(Builder.getInt32Ty(), WrapperArgTys, false));
1408 Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
1409 PointerType *WrapperFuncBitcastType =
1410 FunctionType::get(Builder.getInt32Ty(),
1411 {Builder.getInt32Ty(), Builder.getInt8PtrTy()}, false)
1412 ->getPointerTo();
1413 Value *WrapperFuncBitcast =
1414 ConstantExpr::getBitCast(WrapperFunc, WrapperFuncBitcastType);
1415
1416 // Emit the @__kmpc_omp_task_alloc runtime call
1417     // The runtime call returns a pointer to an area where the task's captured
1418     // variables must be copied before the task is run (NewTaskData).
1419 CallInst *NewTaskData = Builder.CreateCall(
1420 TaskAllocFn,
1421 {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
1422 /*sizeof_task=*/TaskSize, /*sizeof_shared=*/Builder.getInt64(0),
1423 /*task_func=*/WrapperFuncBitcast});
1424
1425 // Copy the arguments for outlined function
1426 if (HasTaskData) {
1427 Value *TaskData = StaleCI->getArgOperand(0);
1428 Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
1429 Builder.CreateMemCpy(NewTaskData, Alignment, TaskData, Alignment,
1430 TaskSize);
1431 }
1432
1433 // Emit the @__kmpc_omp_task runtime call to spawn the task
1434 Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
1435 Builder.CreateCall(TaskFn, {Ident, ThreadID, NewTaskData});
1436
1437 StaleCI->eraseFromParent();
1438
1439 // Emit the body for wrapper function
1440 BasicBlock *WrapperEntryBB =
1441 BasicBlock::Create(M.getContext(), "", WrapperFunc);
1442 Builder.SetInsertPoint(WrapperEntryBB);
1443 if (HasTaskData)
1444 Builder.CreateCall(&OutlinedFn, {WrapperFunc->getArg(1)});
1445 else
1446 Builder.CreateCall(&OutlinedFn);
1447 Builder.CreateRet(Builder.getInt32(0));
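    // The wrapper emitted above therefore looks roughly like this sketch
    // (names and pointer types are illustrative):
    //   define i32 @outlined_fn.wrapper(i32 %gtid, %captures* %task_data) {
    //     call void @outlined_fn(%captures* %task_data)
    //     ret i32 0
    //   }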
1448 };
1449
1450 addOutlineInfo(std::move(OI));
1451
1452 InsertPointTy TaskAllocaIP =
1453 InsertPointTy(TaskAllocaBB, TaskAllocaBB->begin());
1454 InsertPointTy TaskBodyIP = InsertPointTy(TaskBodyBB, TaskBodyBB->begin());
1455 BodyGenCB(TaskAllocaIP, TaskBodyIP);
1456 Builder.SetInsertPoint(TaskExitBB, TaskExitBB->begin());
1457
1458 return Builder.saveIP();
1459 }
1460
1461 OpenMPIRBuilder::InsertPointTy
1462 OpenMPIRBuilder::createTaskgroup(const LocationDescription &Loc,
1463 InsertPointTy AllocaIP,
1464 BodyGenCallbackTy BodyGenCB) {
1465 if (!updateToLocation(Loc))
1466 return InsertPointTy();
1467
1468 uint32_t SrcLocStrSize;
1469 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1470 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1471 Value *ThreadID = getOrCreateThreadID(Ident);
1472
1473 // Emit the @__kmpc_taskgroup runtime call to start the taskgroup
1474 Function *TaskgroupFn =
1475 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
1476 Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
1477
1478 BasicBlock *TaskgroupExitBB = splitBB(Builder, true, "taskgroup.exit");
1479 BodyGenCB(AllocaIP, Builder.saveIP());
1480
1481 Builder.SetInsertPoint(TaskgroupExitBB);
1482 // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
1483 Function *EndTaskgroupFn =
1484 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
1485 Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
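  // The overall emitted frame is therefore (sketch, for illustration):
  //   call void @__kmpc_taskgroup(%struct.ident_t* @loc, i32 %tid)
  //   ; ... taskgroup body, including any spawned tasks ...
  //   call void @__kmpc_end_taskgroup(%struct.ident_t* @loc, i32 %tid)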
1486
1487 return Builder.saveIP();
1488 }
1489
1490 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSections(
1491 const LocationDescription &Loc, InsertPointTy AllocaIP,
1492 ArrayRef<StorableBodyGenCallbackTy> SectionCBs, PrivatizeCallbackTy PrivCB,
1493 FinalizeCallbackTy FiniCB, bool IsCancellable, bool IsNowait) {
1494 assert(!isConflictIP(AllocaIP, Loc.IP) && "Dedicated IP allocas required");
1495
1496 if (!updateToLocation(Loc))
1497 return Loc.IP;
1498
1499 auto FiniCBWrapper = [&](InsertPointTy IP) {
1500 if (IP.getBlock()->end() != IP.getPoint())
1501 return FiniCB(IP);
1502     // This must be done; otherwise, any nested construct using FinalizeOMPRegion
1503     // will fail, because that function requires the finalization basic block to
1504     // have a terminator, which EmitOMPRegionBody has already removed.
1505     // IP currently points at the cancellation block, so
1506     // we need to backtrack to the condition block to fetch
1507     // the exit block and create a branch from the cancellation
1508     // block to the exit block.
1509 IRBuilder<>::InsertPointGuard IPG(Builder);
1510 Builder.restoreIP(IP);
1511 auto *CaseBB = IP.getBlock()->getSinglePredecessor();
1512 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1513 auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1514 Instruction *I = Builder.CreateBr(ExitBB);
1515 IP = InsertPointTy(I->getParent(), I->getIterator());
1516 return FiniCB(IP);
1517 };
1518
1519 FinalizationStack.push_back({FiniCBWrapper, OMPD_sections, IsCancellable});
1520
1521   // Each section is emitted as a switch case.
1522   // Each finalization callback is handled from clang.EmitOMPSectionDirective()
1523   // -> OMP.createSection(), which generates the IR for each section.
1524 // Iterate through all sections and emit a switch construct:
1525 // switch (IV) {
1526 // case 0:
1527 // <SectionStmt[0]>;
1528 // break;
1529 // ...
1530 // case <NumSection> - 1:
1531 // <SectionStmt[<NumSection> - 1]>;
1532 // break;
1533 // }
1534 // ...
1535 // section_loop.after:
1536 // <FiniCB>;
1537 auto LoopBodyGenCB = [&](InsertPointTy CodeGenIP, Value *IndVar) {
1538 Builder.restoreIP(CodeGenIP);
1539 BasicBlock *Continue =
1540 splitBBWithSuffix(Builder, /*CreateBranch=*/false, ".sections.after");
1541 Function *CurFn = Continue->getParent();
1542 SwitchInst *SwitchStmt = Builder.CreateSwitch(IndVar, Continue);
1543
1544 unsigned CaseNumber = 0;
1545 for (auto SectionCB : SectionCBs) {
1546 BasicBlock *CaseBB = BasicBlock::Create(
1547 M.getContext(), "omp_section_loop.body.case", CurFn, Continue);
1548 SwitchStmt->addCase(Builder.getInt32(CaseNumber), CaseBB);
1549 Builder.SetInsertPoint(CaseBB);
1550 BranchInst *CaseEndBr = Builder.CreateBr(Continue);
1551 SectionCB(InsertPointTy(),
1552 {CaseEndBr->getParent(), CaseEndBr->getIterator()});
1553 CaseNumber++;
1554 }
1555     // The body block's previous terminator was already moved into the Continue
1556     // block by splitBBWithSuffix; the switch above now terminates the body.
1557 };
1558   // Loop body ends here.
1559   // LowerBound, UpperBound, and Stride for createCanonicalLoop.
1560 Type *I32Ty = Type::getInt32Ty(M.getContext());
1561 Value *LB = ConstantInt::get(I32Ty, 0);
1562 Value *UB = ConstantInt::get(I32Ty, SectionCBs.size());
1563 Value *ST = ConstantInt::get(I32Ty, 1);
1564 llvm::CanonicalLoopInfo *LoopInfo = createCanonicalLoop(
1565 Loc, LoopBodyGenCB, LB, UB, ST, true, false, AllocaIP, "section_loop");
1566 InsertPointTy AfterIP =
1567 applyStaticWorkshareLoop(Loc.DL, LoopInfo, AllocaIP, !IsNowait);
1568
1569 // Apply the finalization callback in LoopAfterBB
1570 auto FiniInfo = FinalizationStack.pop_back_val();
1571 assert(FiniInfo.DK == OMPD_sections &&
1572 "Unexpected finalization stack state!");
1573 if (FinalizeCallbackTy &CB = FiniInfo.FiniCB) {
1574 Builder.restoreIP(AfterIP);
1575 BasicBlock *FiniBB =
1576 splitBBWithSuffix(Builder, /*CreateBranch=*/true, "sections.fini");
1577 CB(Builder.saveIP());
1578 AfterIP = {FiniBB, FiniBB->begin()};
1579 }
1580
1581 return AfterIP;
1582 }
1583
1584 OpenMPIRBuilder::InsertPointTy
1585 OpenMPIRBuilder::createSection(const LocationDescription &Loc,
1586 BodyGenCallbackTy BodyGenCB,
1587 FinalizeCallbackTy FiniCB) {
1588 if (!updateToLocation(Loc))
1589 return Loc.IP;
1590
1591 auto FiniCBWrapper = [&](InsertPointTy IP) {
1592 if (IP.getBlock()->end() != IP.getPoint())
1593 return FiniCB(IP);
1594     // This must be done; otherwise, any nested construct using FinalizeOMPRegion
1595     // will fail, because that function requires the finalization basic block to
1596     // have a terminator, which EmitOMPRegionBody has already removed.
1597     // IP currently points at the cancellation block, so
1598     // we need to backtrack to the condition block to fetch
1599     // the exit block and create a branch from the cancellation
1600     // block to the exit block.
1601 IRBuilder<>::InsertPointGuard IPG(Builder);
1602 Builder.restoreIP(IP);
1603 auto *CaseBB = Loc.IP.getBlock();
1604 auto *CondBB = CaseBB->getSinglePredecessor()->getSinglePredecessor();
1605 auto *ExitBB = CondBB->getTerminator()->getSuccessor(1);
1606 Instruction *I = Builder.CreateBr(ExitBB);
1607 IP = InsertPointTy(I->getParent(), I->getIterator());
1608 return FiniCB(IP);
1609 };
1610
1611 Directive OMPD = Directive::OMPD_sections;
1612   // Since we are using a finalization callback here, HasFinalize
1613   // and IsCancellable have to be true.
1614 return EmitOMPInlinedRegion(OMPD, nullptr, nullptr, BodyGenCB, FiniCBWrapper,
1615 /*Conditional*/ false, /*hasFinalize*/ true,
1616 /*IsCancellable*/ true);
1617 }
1618
1619 /// Create a function with a unique name and a "void (i8*, i8*)" signature in
1620 /// the given module and return it.
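/// For illustration, the created function is a bare definition-to-be along
/// the lines of
///   define internal void @.omp.reduction.func(i8* %0, i8* %1)
/// whose body the caller is expected to fill in.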
1621 Function *getFreshReductionFunc(Module &M) {
1622 Type *VoidTy = Type::getVoidTy(M.getContext());
1623 Type *Int8PtrTy = Type::getInt8PtrTy(M.getContext());
1624 auto *FuncTy =
1625 FunctionType::get(VoidTy, {Int8PtrTy, Int8PtrTy}, /* IsVarArg */ false);
1626 return Function::Create(FuncTy, GlobalVariable::InternalLinkage,
1627 M.getDataLayout().getDefaultGlobalsAddressSpace(),
1628 ".omp.reduction.func", &M);
1629 }
1630
1631 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createReductions(
1632 const LocationDescription &Loc, InsertPointTy AllocaIP,
1633 ArrayRef<ReductionInfo> ReductionInfos, bool IsNoWait) {
1634 for (const ReductionInfo &RI : ReductionInfos) {
1635 (void)RI;
1636 assert(RI.Variable && "expected non-null variable");
1637 assert(RI.PrivateVariable && "expected non-null private variable");
1638 assert(RI.ReductionGen && "expected non-null reduction generator callback");
1639 assert(RI.Variable->getType() == RI.PrivateVariable->getType() &&
1640 "expected variables and their private equivalents to have the same "
1641 "type");
1642 assert(RI.Variable->getType()->isPointerTy() &&
1643 "expected variables to be pointers");
1644 }
1645
1646 if (!updateToLocation(Loc))
1647 return InsertPointTy();
1648
1649 BasicBlock *InsertBlock = Loc.IP.getBlock();
1650 BasicBlock *ContinuationBlock =
1651 InsertBlock->splitBasicBlock(Loc.IP.getPoint(), "reduce.finalize");
1652 InsertBlock->getTerminator()->eraseFromParent();
1653
1654 // Create and populate array of type-erased pointers to private reduction
1655 // values.
1656 unsigned NumReductions = ReductionInfos.size();
1657 Type *RedArrayTy = ArrayType::get(Builder.getInt8PtrTy(), NumReductions);
1658 Builder.restoreIP(AllocaIP);
1659 Value *RedArray = Builder.CreateAlloca(RedArrayTy, nullptr, "red.array");
1660
1661 Builder.SetInsertPoint(InsertBlock, InsertBlock->end());
1662
1663 for (auto En : enumerate(ReductionInfos)) {
1664 unsigned Index = En.index();
1665 const ReductionInfo &RI = En.value();
1666 Value *RedArrayElemPtr = Builder.CreateConstInBoundsGEP2_64(
1667 RedArrayTy, RedArray, 0, Index, "red.array.elem." + Twine(Index));
1668 Value *Casted =
1669 Builder.CreateBitCast(RI.PrivateVariable, Builder.getInt8PtrTy(),
1670 "private.red.var." + Twine(Index) + ".casted");
1671 Builder.CreateStore(Casted, RedArrayElemPtr);
1672 }
1673
1674 // Emit a call to the runtime function that orchestrates the reduction.
1675 // Declare the reduction function in the process.
1676 Function *Func = Builder.GetInsertBlock()->getParent();
1677 Module *Module = Func->getParent();
1678 Value *RedArrayPtr =
1679 Builder.CreateBitCast(RedArray, Builder.getInt8PtrTy(), "red.array.ptr");
1680 uint32_t SrcLocStrSize;
1681 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1682 bool CanGenerateAtomic =
1683 llvm::all_of(ReductionInfos, [](const ReductionInfo &RI) {
1684 return RI.AtomicReductionGen;
1685 });
1686 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize,
1687 CanGenerateAtomic
1688 ? IdentFlag::OMP_IDENT_FLAG_ATOMIC_REDUCE
1689 : IdentFlag(0));
1690 Value *ThreadId = getOrCreateThreadID(Ident);
1691 Constant *NumVariables = Builder.getInt32(NumReductions);
1692 const DataLayout &DL = Module->getDataLayout();
1693 unsigned RedArrayByteSize = DL.getTypeStoreSize(RedArrayTy);
1694 Constant *RedArraySize = Builder.getInt64(RedArrayByteSize);
1695 Function *ReductionFunc = getFreshReductionFunc(*Module);
1696 Value *Lock = getOMPCriticalRegionLock(".reduction");
1697 Function *ReduceFunc = getOrCreateRuntimeFunctionPtr(
1698 IsNoWait ? RuntimeFunction::OMPRTL___kmpc_reduce_nowait
1699 : RuntimeFunction::OMPRTL___kmpc_reduce);
1700 CallInst *ReduceCall =
1701 Builder.CreateCall(ReduceFunc,
1702 {Ident, ThreadId, NumVariables, RedArraySize,
1703 RedArrayPtr, ReductionFunc, Lock},
1704 "reduce");
1705
1706 // Create final reduction entry blocks for the atomic and non-atomic case.
1707 // Emit IR that dispatches control flow to one of the blocks based on the
1708 // reduction supporting the atomic mode.
1709 BasicBlock *NonAtomicRedBlock =
1710 BasicBlock::Create(Module->getContext(), "reduce.switch.nonatomic", Func);
1711 BasicBlock *AtomicRedBlock =
1712 BasicBlock::Create(Module->getContext(), "reduce.switch.atomic", Func);
1713 SwitchInst *Switch =
1714 Builder.CreateSwitch(ReduceCall, ContinuationBlock, /* NumCases */ 2);
1715 Switch->addCase(Builder.getInt32(1), NonAtomicRedBlock);
1716 Switch->addCase(Builder.getInt32(2), AtomicRedBlock);
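  // The dispatch emitted above looks roughly like this sketch:
  //   %reduce = call i32 @__kmpc_reduce(...)
  //   switch i32 %reduce, label %reduce.finalize [
  //     i32 1, label %reduce.switch.nonatomic
  //     i32 2, label %reduce.switch.atomic
  //   ]
  // The runtime returns 1 to request a non-atomic reduction, 2 to request an
  // atomic one, and 0 when this thread does not participate in the reduction.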
1717
1718 // Populate the non-atomic reduction using the elementwise reduction function.
1719 // This loads the elements from the global and private variables and reduces
1720 // them before storing back the result to the global variable.
1721 Builder.SetInsertPoint(NonAtomicRedBlock);
1722 for (auto En : enumerate(ReductionInfos)) {
1723 const ReductionInfo &RI = En.value();
1724 Type *ValueType = RI.ElementType;
1725 Value *RedValue = Builder.CreateLoad(ValueType, RI.Variable,
1726 "red.value." + Twine(En.index()));
1727 Value *PrivateRedValue =
1728 Builder.CreateLoad(ValueType, RI.PrivateVariable,
1729 "red.private.value." + Twine(En.index()));
1730 Value *Reduced;
1731 Builder.restoreIP(
1732 RI.ReductionGen(Builder.saveIP(), RedValue, PrivateRedValue, Reduced));
1733 if (!Builder.GetInsertBlock())
1734 return InsertPointTy();
1735 Builder.CreateStore(Reduced, RI.Variable);
1736 }
1737 Function *EndReduceFunc = getOrCreateRuntimeFunctionPtr(
1738 IsNoWait ? RuntimeFunction::OMPRTL___kmpc_end_reduce_nowait
1739 : RuntimeFunction::OMPRTL___kmpc_end_reduce);
1740 Builder.CreateCall(EndReduceFunc, {Ident, ThreadId, Lock});
1741 Builder.CreateBr(ContinuationBlock);
1742
1743 // Populate the atomic reduction using the atomic elementwise reduction
1744 // function. There are no loads/stores here because they will be happening
1745 // inside the atomic elementwise reduction.
1746 Builder.SetInsertPoint(AtomicRedBlock);
1747 if (CanGenerateAtomic) {
1748 for (const ReductionInfo &RI : ReductionInfos) {
1749 Builder.restoreIP(RI.AtomicReductionGen(Builder.saveIP(), RI.ElementType,
1750 RI.Variable, RI.PrivateVariable));
1751 if (!Builder.GetInsertBlock())
1752 return InsertPointTy();
1753 }
1754 Builder.CreateBr(ContinuationBlock);
1755 } else {
1756 Builder.CreateUnreachable();
1757 }
1758
1759 // Populate the outlined reduction function using the elementwise reduction
1760 // function. Partial values are extracted from the type-erased array of
1761 // pointers to private variables.
1762 BasicBlock *ReductionFuncBlock =
1763 BasicBlock::Create(Module->getContext(), "", ReductionFunc);
1764 Builder.SetInsertPoint(ReductionFuncBlock);
1765 Value *LHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(0),
1766 RedArrayTy->getPointerTo());
1767 Value *RHSArrayPtr = Builder.CreateBitCast(ReductionFunc->getArg(1),
1768 RedArrayTy->getPointerTo());
1769 for (auto En : enumerate(ReductionInfos)) {
1770 const ReductionInfo &RI = En.value();
1771 Value *LHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
1772 RedArrayTy, LHSArrayPtr, 0, En.index());
1773 Value *LHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), LHSI8PtrPtr);
1774 Value *LHSPtr = Builder.CreateBitCast(LHSI8Ptr, RI.Variable->getType());
1775 Value *LHS = Builder.CreateLoad(RI.ElementType, LHSPtr);
1776 Value *RHSI8PtrPtr = Builder.CreateConstInBoundsGEP2_64(
1777 RedArrayTy, RHSArrayPtr, 0, En.index());
1778 Value *RHSI8Ptr = Builder.CreateLoad(Builder.getInt8PtrTy(), RHSI8PtrPtr);
1779 Value *RHSPtr =
1780 Builder.CreateBitCast(RHSI8Ptr, RI.PrivateVariable->getType());
1781 Value *RHS = Builder.CreateLoad(RI.ElementType, RHSPtr);
1782 Value *Reduced;
1783 Builder.restoreIP(RI.ReductionGen(Builder.saveIP(), LHS, RHS, Reduced));
1784 if (!Builder.GetInsertBlock())
1785 return InsertPointTy();
1786 Builder.CreateStore(Reduced, LHSPtr);
1787 }
1788 Builder.CreateRetVoid();
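  // For a single integer "+" reduction, the function built above is roughly
  // (sketch; the combining operation itself comes from ReductionGen):
  //   define internal void @.omp.reduction.func(i8* %lhs.arr, i8* %rhs.arr) {
  //     ; load the element pointers from both type-erased arrays, then:
  //     %red = add i32 %lhs, %rhs
  //     store i32 %red, i32* %lhs.ptr
  //     ret void
  //   }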
1789
1790 Builder.SetInsertPoint(ContinuationBlock);
1791 return Builder.saveIP();
1792 }
1793
1794 OpenMPIRBuilder::InsertPointTy
1795 OpenMPIRBuilder::createMaster(const LocationDescription &Loc,
1796 BodyGenCallbackTy BodyGenCB,
1797 FinalizeCallbackTy FiniCB) {
1798
1799 if (!updateToLocation(Loc))
1800 return Loc.IP;
1801
1802 Directive OMPD = Directive::OMPD_master;
1803 uint32_t SrcLocStrSize;
1804 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1805 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1806 Value *ThreadId = getOrCreateThreadID(Ident);
1807 Value *Args[] = {Ident, ThreadId};
1808
1809 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_master);
1810 Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1811
1812 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_master);
1813 Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
1814
1815 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1816 /*Conditional*/ true, /*hasFinalize*/ true);
1817 }
1818
1819 OpenMPIRBuilder::InsertPointTy
1820 OpenMPIRBuilder::createMasked(const LocationDescription &Loc,
1821 BodyGenCallbackTy BodyGenCB,
1822 FinalizeCallbackTy FiniCB, Value *Filter) {
1823 if (!updateToLocation(Loc))
1824 return Loc.IP;
1825
1826 Directive OMPD = Directive::OMPD_masked;
1827 uint32_t SrcLocStrSize;
1828 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
1829 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
1830 Value *ThreadId = getOrCreateThreadID(Ident);
1831 Value *Args[] = {Ident, ThreadId, Filter};
1832 Value *ArgsEnd[] = {Ident, ThreadId};
1833
1834 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_masked);
1835 Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
1836
1837 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_masked);
1838 Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, ArgsEnd);
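  // Since Conditional is true below, EmitOMPInlinedRegion guards the region
  // body with the entry call's result, roughly (sketch; labels illustrative):
  //   %res = call i32 @__kmpc_masked(%struct.ident_t* @loc, i32 %tid, i32 %filter)
  //   %cond = icmp ne i32 %res, 0
  //   br i1 %cond, label %masked.body, label %masked.end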
1839
1840 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
1841 /*Conditional*/ true, /*hasFinalize*/ true);
1842 }
1843
1844 CanonicalLoopInfo *OpenMPIRBuilder::createLoopSkeleton(
1845 DebugLoc DL, Value *TripCount, Function *F, BasicBlock *PreInsertBefore,
1846 BasicBlock *PostInsertBefore, const Twine &Name) {
1847 Module *M = F->getParent();
1848 LLVMContext &Ctx = M->getContext();
1849 Type *IndVarTy = TripCount->getType();
1850
1851 // Create the basic block structure.
1852 BasicBlock *Preheader =
1853 BasicBlock::Create(Ctx, "omp_" + Name + ".preheader", F, PreInsertBefore);
1854 BasicBlock *Header =
1855 BasicBlock::Create(Ctx, "omp_" + Name + ".header", F, PreInsertBefore);
1856 BasicBlock *Cond =
1857 BasicBlock::Create(Ctx, "omp_" + Name + ".cond", F, PreInsertBefore);
1858 BasicBlock *Body =
1859 BasicBlock::Create(Ctx, "omp_" + Name + ".body", F, PreInsertBefore);
1860 BasicBlock *Latch =
1861 BasicBlock::Create(Ctx, "omp_" + Name + ".inc", F, PostInsertBefore);
1862 BasicBlock *Exit =
1863 BasicBlock::Create(Ctx, "omp_" + Name + ".exit", F, PostInsertBefore);
1864 BasicBlock *After =
1865 BasicBlock::Create(Ctx, "omp_" + Name + ".after", F, PostInsertBefore);
1866
1867 // Use specified DebugLoc for new instructions.
1868 Builder.SetCurrentDebugLocation(DL);
1869
1870 Builder.SetInsertPoint(Preheader);
1871 Builder.CreateBr(Header);
1872
1873 Builder.SetInsertPoint(Header);
1874 PHINode *IndVarPHI = Builder.CreatePHI(IndVarTy, 2, "omp_" + Name + ".iv");
1875 IndVarPHI->addIncoming(ConstantInt::get(IndVarTy, 0), Preheader);
1876 Builder.CreateBr(Cond);
1877
1878 Builder.SetInsertPoint(Cond);
1879 Value *Cmp =
1880 Builder.CreateICmpULT(IndVarPHI, TripCount, "omp_" + Name + ".cmp");
1881 Builder.CreateCondBr(Cmp, Body, Exit);
1882
1883 Builder.SetInsertPoint(Body);
1884 Builder.CreateBr(Latch);
1885
1886 Builder.SetInsertPoint(Latch);
1887 Value *Next = Builder.CreateAdd(IndVarPHI, ConstantInt::get(IndVarTy, 1),
1888 "omp_" + Name + ".next", /*HasNUW=*/true);
1889 Builder.CreateBr(Header);
1890 IndVarPHI->addIncoming(Next, Latch);
1891
1892 Builder.SetInsertPoint(Exit);
1893 Builder.CreateBr(After);
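  // The skeleton built above has the following shape (body still empty):
  //   preheader -> header -> cond --(iv u< tripcount)--> body -> inc -> header
  //                          cond --(otherwise)--------> exit -> after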
1894
1895 // Remember and return the canonical control flow.
1896 LoopInfos.emplace_front();
1897 CanonicalLoopInfo *CL = &LoopInfos.front();
1898
1899 CL->Header = Header;
1900 CL->Cond = Cond;
1901 CL->Latch = Latch;
1902 CL->Exit = Exit;
1903
1904 #ifndef NDEBUG
1905 CL->assertOK();
1906 #endif
1907 return CL;
1908 }
1909
1910 CanonicalLoopInfo *
1911 OpenMPIRBuilder::createCanonicalLoop(const LocationDescription &Loc,
1912 LoopBodyGenCallbackTy BodyGenCB,
1913 Value *TripCount, const Twine &Name) {
1914 BasicBlock *BB = Loc.IP.getBlock();
1915 BasicBlock *NextBB = BB->getNextNode();
1916
1917 CanonicalLoopInfo *CL = createLoopSkeleton(Loc.DL, TripCount, BB->getParent(),
1918 NextBB, NextBB, Name);
1919 BasicBlock *After = CL->getAfter();
1920
1921 // If location is not set, don't connect the loop.
1922 if (updateToLocation(Loc)) {
1923 // Split the loop at the insertion point: Branch to the preheader and move
1924 // every following instruction to after the loop (the After BB). Also, the
1925 // new successor is the loop's after block.
1926 spliceBB(Builder, After, /*CreateBranch=*/false);
1927 Builder.CreateBr(CL->getPreheader());
1928 }
1929
1930   // Emit the body content. We do it after connecting the loop to the CFG so
1931   // that the callback does not encounter degenerate BBs.
1932 BodyGenCB(CL->getBodyIP(), CL->getIndVar());
1933
1934 #ifndef NDEBUG
1935 CL->assertOK();
1936 #endif
1937 return CL;
1938 }
1939
1940 CanonicalLoopInfo *OpenMPIRBuilder::createCanonicalLoop(
1941 const LocationDescription &Loc, LoopBodyGenCallbackTy BodyGenCB,
1942 Value *Start, Value *Stop, Value *Step, bool IsSigned, bool InclusiveStop,
1943 InsertPointTy ComputeIP, const Twine &Name) {
1944
1945 // Consider the following difficulties (assuming 8-bit signed integers):
1946 // * Adding \p Step to the loop counter which passes \p Stop may overflow:
1947 // DO I = 1, 100, 50
1948   // * A \p Step of INT_MIN cannot be normalized to a positive direction:
1949 // DO I = 100, 0, -128
1950
1951 // Start, Stop and Step must be of the same integer type.
1952 auto *IndVarTy = cast<IntegerType>(Start->getType());
1953 assert(IndVarTy == Stop->getType() && "Stop type mismatch");
1954 assert(IndVarTy == Step->getType() && "Step type mismatch");
1955
1956 LocationDescription ComputeLoc =
1957 ComputeIP.isSet() ? LocationDescription(ComputeIP, Loc.DL) : Loc;
1958 updateToLocation(ComputeLoc);
1959
1960 ConstantInt *Zero = ConstantInt::get(IndVarTy, 0);
1961 ConstantInt *One = ConstantInt::get(IndVarTy, 1);
1962
1963 // Like Step, but always positive.
1964 Value *Incr = Step;
1965
1966 // Distance between Start and Stop; always positive.
1967 Value *Span;
1968
1969   // Condition under which no iterations are executed at all, e.g. because
1970 // UB < LB.
1971 Value *ZeroCmp;
1972
1973 if (IsSigned) {
1974 // Ensure that increment is positive. If not, negate and invert LB and UB.
1975 Value *IsNeg = Builder.CreateICmpSLT(Step, Zero);
1976 Incr = Builder.CreateSelect(IsNeg, Builder.CreateNeg(Step), Step);
1977 Value *LB = Builder.CreateSelect(IsNeg, Stop, Start);
1978 Value *UB = Builder.CreateSelect(IsNeg, Start, Stop);
1979 Span = Builder.CreateSub(UB, LB, "", false, true);
1980 ZeroCmp = Builder.CreateICmp(
1981 InclusiveStop ? CmpInst::ICMP_SLT : CmpInst::ICMP_SLE, UB, LB);
1982 } else {
1983 Span = Builder.CreateSub(Stop, Start, "", true);
1984 ZeroCmp = Builder.CreateICmp(
1985 InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Stop, Start);
1986 }
1987
1988 Value *CountIfLooping;
1989 if (InclusiveStop) {
1990 CountIfLooping = Builder.CreateAdd(Builder.CreateUDiv(Span, Incr), One);
1991 } else {
1992 // Avoid incrementing past stop since it could overflow.
1993 Value *CountIfTwo = Builder.CreateAdd(
1994 Builder.CreateUDiv(Builder.CreateSub(Span, One), Incr), One);
1995 Value *OneCmp = Builder.CreateICmp(
1996 InclusiveStop ? CmpInst::ICMP_ULT : CmpInst::ICMP_ULE, Span, Incr);
1997 CountIfLooping = Builder.CreateSelect(OneCmp, One, CountIfTwo);
1998 }
1999 Value *TripCount = Builder.CreateSelect(ZeroCmp, Zero, CountIfLooping,
2000 "omp_" + Name + ".tripcount");
2001
2002 auto BodyGen = [=](InsertPointTy CodeGenIP, Value *IV) {
2003 Builder.restoreIP(CodeGenIP);
2004 Value *Span = Builder.CreateMul(IV, Step);
2005 Value *IndVar = Builder.CreateAdd(Span, Start);
2006 BodyGenCB(Builder.saveIP(), IndVar);
2007 };
2008 LocationDescription LoopLoc = ComputeIP.isSet() ? Loc.IP : Builder.saveIP();
2009 return createCanonicalLoop(LoopLoc, BodyGen, TripCount, Name);
2010 }
2011
2012 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
2013 /// static scheduling depending on `type`. Only i32 and i64 are supported by the
2014 /// runtime. Always interpret integers as unsigned similarly to
2015 /// CanonicalLoopInfo.
2016 static FunctionCallee getKmpcForStaticInitForType(Type *Ty, Module &M,
2017 OpenMPIRBuilder &OMPBuilder) {
2018 unsigned Bitwidth = Ty->getIntegerBitWidth();
2019 if (Bitwidth == 32)
2020 return OMPBuilder.getOrCreateRuntimeFunction(
2021 M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_4u);
2022 if (Bitwidth == 64)
2023 return OMPBuilder.getOrCreateRuntimeFunction(
2024 M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_init_8u);
2025 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2026 }
2027
2028 OpenMPIRBuilder::InsertPointTy
2029 OpenMPIRBuilder::applyStaticWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI,
2030 InsertPointTy AllocaIP,
2031 bool NeedsBarrier) {
2032 assert(CLI->isValid() && "Requires a valid canonical loop");
2033 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
2034 "Require dedicated allocate IP");
2035
2036 // Set up the source location value for OpenMP runtime.
2037 Builder.restoreIP(CLI->getPreheaderIP());
2038 Builder.SetCurrentDebugLocation(DL);
2039
2040 uint32_t SrcLocStrSize;
2041 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
2042 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2043
2044 // Declare useful OpenMP runtime functions.
2045 Value *IV = CLI->getIndVar();
2046 Type *IVTy = IV->getType();
2047 FunctionCallee StaticInit = getKmpcForStaticInitForType(IVTy, M, *this);
2048 FunctionCallee StaticFini =
2049 getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
2050
2051 // Allocate space for computed loop bounds as expected by the "init" function.
2052 Builder.restoreIP(AllocaIP);
2053 Type *I32Type = Type::getInt32Ty(M.getContext());
2054 Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
2055 Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
2056 Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
2057 Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
2058
2059 // At the end of the preheader, prepare for calling the "init" function by
2060 // storing the current loop bounds into the allocated space. A canonical loop
2061 // always iterates from 0 to trip-count with step 1. Note that "init" expects
2062 // and produces an inclusive upper bound.
2063 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2064 Constant *Zero = ConstantInt::get(IVTy, 0);
2065 Constant *One = ConstantInt::get(IVTy, 1);
2066 Builder.CreateStore(Zero, PLowerBound);
2067 Value *UpperBound = Builder.CreateSub(CLI->getTripCount(), One);
2068 Builder.CreateStore(UpperBound, PUpperBound);
2069 Builder.CreateStore(One, PStride);
2070
2071 Value *ThreadNum = getOrCreateThreadID(SrcLoc);
2072
2073 Constant *SchedulingType = ConstantInt::get(
2074 I32Type, static_cast<int>(OMPScheduleType::UnorderedStatic));
2075
2076 // Call the "init" function and update the trip count of the loop with the
2077 // value it produced.
2078 Builder.CreateCall(StaticInit,
2079 {SrcLoc, ThreadNum, SchedulingType, PLastIter, PLowerBound,
2080 PUpperBound, PStride, One, Zero});
2081 Value *LowerBound = Builder.CreateLoad(IVTy, PLowerBound);
2082 Value *InclusiveUpperBound = Builder.CreateLoad(IVTy, PUpperBound);
2083 Value *TripCountMinusOne = Builder.CreateSub(InclusiveUpperBound, LowerBound);
2084 Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One);
2085 CLI->setTripCount(TripCount);
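  // For illustration (assumed example, not emitted code): with a trip count
  // of 10 and two threads, the "init" call hands thread 0 the inclusive
  // bounds [0, 4] and thread 1 the bounds [5, 9]; each thread then runs its
  // rebased chunk of 5 iterations, with LowerBound added back onto the IV below.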
2086
2087 // Update all uses of the induction variable except the one in the condition
2088 // block that compares it with the actual upper bound, and the increment in
2089 // the latch block.
2090
2091 CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
2092 Builder.SetInsertPoint(CLI->getBody(),
2093 CLI->getBody()->getFirstInsertionPt());
2094 Builder.SetCurrentDebugLocation(DL);
2095 return Builder.CreateAdd(OldIV, LowerBound);
2096 });
2097
2098 // In the "exit" block, call the "fini" function.
2099 Builder.SetInsertPoint(CLI->getExit(),
2100 CLI->getExit()->getTerminator()->getIterator());
2101 Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
2102
2103 // Add the barrier if requested.
2104 if (NeedsBarrier)
2105 createBarrier(LocationDescription(Builder.saveIP(), DL),
2106 omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
2107 /* CheckCancelFlag */ false);
2108
2109 InsertPointTy AfterIP = CLI->getAfterIP();
2110 CLI->invalidate();
2111
2112 return AfterIP;
2113 }
2114
2115 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
2116 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2117 bool NeedsBarrier, Value *ChunkSize) {
2118 assert(CLI->isValid() && "Requires a valid canonical loop");
2119 assert(ChunkSize && "Chunk size is required");
2120
2121 LLVMContext &Ctx = CLI->getFunction()->getContext();
2122 Value *IV = CLI->getIndVar();
2123 Value *OrigTripCount = CLI->getTripCount();
2124 Type *IVTy = IV->getType();
2125 assert(IVTy->getIntegerBitWidth() <= 64 &&
2126 "Max supported tripcount bitwidth is 64 bits");
2127 Type *InternalIVTy = IVTy->getIntegerBitWidth() <= 32 ? Type::getInt32Ty(Ctx)
2128 : Type::getInt64Ty(Ctx);
2129 Type *I32Type = Type::getInt32Ty(M.getContext());
2130 Constant *Zero = ConstantInt::get(InternalIVTy, 0);
2131 Constant *One = ConstantInt::get(InternalIVTy, 1);
2132
2133 // Declare useful OpenMP runtime functions.
2134 FunctionCallee StaticInit =
2135 getKmpcForStaticInitForType(InternalIVTy, M, *this);
2136 FunctionCallee StaticFini =
2137 getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_for_static_fini);
2138
2139 // Allocate space for computed loop bounds as expected by the "init" function.
2140 Builder.restoreIP(AllocaIP);
2141 Builder.SetCurrentDebugLocation(DL);
2142 Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
2143 Value *PLowerBound =
2144 Builder.CreateAlloca(InternalIVTy, nullptr, "p.lowerbound");
2145 Value *PUpperBound =
2146 Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound");
2147 Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride");
2148
2149 // Set up the source location value for the OpenMP runtime.
2150 Builder.restoreIP(CLI->getPreheaderIP());
2151 Builder.SetCurrentDebugLocation(DL);
2152
2153 // TODO: Detect overflow in ubsan or max-out with current tripcount.
2154 Value *CastedChunkSize =
2155 Builder.CreateZExtOrTrunc(ChunkSize, InternalIVTy, "chunksize");
2156 Value *CastedTripCount =
2157 Builder.CreateZExt(OrigTripCount, InternalIVTy, "tripcount");
2158
2159 Constant *SchedulingType = ConstantInt::get(
2160 I32Type, static_cast<int>(OMPScheduleType::UnorderedStaticChunked));
2161 Builder.CreateStore(Zero, PLowerBound);
2162 Value *OrigUpperBound = Builder.CreateSub(CastedTripCount, One);
2163 Builder.CreateStore(OrigUpperBound, PUpperBound);
2164 Builder.CreateStore(One, PStride);
2165
2166 // Call the "init" function and update the trip count of the loop with the
2167 // value it produced.
2168 uint32_t SrcLocStrSize;
2169 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
2170 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2171 Value *ThreadNum = getOrCreateThreadID(SrcLoc);
2172 Builder.CreateCall(StaticInit,
2173 {/*loc=*/SrcLoc, /*global_tid=*/ThreadNum,
2174 /*schedtype=*/SchedulingType, /*plastiter=*/PLastIter,
2175 /*plower=*/PLowerBound, /*pupper=*/PUpperBound,
2176 /*pstride=*/PStride, /*incr=*/One,
2177 /*chunk=*/CastedChunkSize});
2178
2179 // Load values written by the "init" function.
2180 Value *FirstChunkStart =
2181 Builder.CreateLoad(InternalIVTy, PLowerBound, "omp_firstchunk.lb");
2182 Value *FirstChunkStop =
2183 Builder.CreateLoad(InternalIVTy, PUpperBound, "omp_firstchunk.ub");
2184 Value *FirstChunkEnd = Builder.CreateAdd(FirstChunkStop, One);
2185 Value *ChunkRange =
2186 Builder.CreateSub(FirstChunkEnd, FirstChunkStart, "omp_chunk.range");
2187 Value *NextChunkStride =
2188 Builder.CreateLoad(InternalIVTy, PStride, "omp_dispatch.stride");
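  // For illustration (assumed example): with a trip count of 10, chunk size 4
  // and two threads, thread 0 receives the first chunk [0, 3] and stride 8,
  // so it also executes the partial chunk starting at 8, while thread 1
  // receives [4, 7]. The dispatch loop built below enumerates these starts.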
2189
2190 // Create outer "dispatch" loop for enumerating the chunks.
2191 BasicBlock *DispatchEnter = splitBB(Builder, true);
2192 Value *DispatchCounter;
2193 CanonicalLoopInfo *DispatchCLI = createCanonicalLoop(
2194 {Builder.saveIP(), DL},
2195 [&](InsertPointTy BodyIP, Value *Counter) { DispatchCounter = Counter; },
2196 FirstChunkStart, CastedTripCount, NextChunkStride,
2197 /*IsSigned=*/false, /*InclusiveStop=*/false, /*ComputeIP=*/{},
2198 "dispatch");
2199
2200   // Remember the BasicBlocks of the dispatch loop we need, then invalidate the
2201   // CLI so that we do not have to preserve the canonical invariant.
2202 BasicBlock *DispatchBody = DispatchCLI->getBody();
2203 BasicBlock *DispatchLatch = DispatchCLI->getLatch();
2204 BasicBlock *DispatchExit = DispatchCLI->getExit();
2205 BasicBlock *DispatchAfter = DispatchCLI->getAfter();
2206 DispatchCLI->invalidate();
2207
2208 // Rewire the original loop to become the chunk loop inside the dispatch loop.
2209 redirectTo(DispatchAfter, CLI->getAfter(), DL);
2210 redirectTo(CLI->getExit(), DispatchLatch, DL);
2211 redirectTo(DispatchBody, DispatchEnter, DL);
2212
2213 // Prepare the prolog of the chunk loop.
2214 Builder.restoreIP(CLI->getPreheaderIP());
2215 Builder.SetCurrentDebugLocation(DL);
2216
2217 // Compute the number of iterations of the chunk loop.
2218 Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
2219 Value *ChunkEnd = Builder.CreateAdd(DispatchCounter, ChunkRange);
2220 Value *IsLastChunk =
2221 Builder.CreateICmpUGE(ChunkEnd, CastedTripCount, "omp_chunk.is_last");
2222 Value *CountUntilOrigTripCount =
2223 Builder.CreateSub(CastedTripCount, DispatchCounter);
2224 Value *ChunkTripCount = Builder.CreateSelect(
2225 IsLastChunk, CountUntilOrigTripCount, ChunkRange, "omp_chunk.tripcount");
2226 Value *BackcastedChunkTC =
2227 Builder.CreateTrunc(ChunkTripCount, IVTy, "omp_chunk.tripcount.trunc");
2228 CLI->setTripCount(BackcastedChunkTC);
2229
2230 // Update all uses of the induction variable except the one in the condition
2231 // block that compares it with the actual upper bound, and the increment in
2232 // the latch block.
2233 Value *BackcastedDispatchCounter =
2234 Builder.CreateTrunc(DispatchCounter, IVTy, "omp_dispatch.iv.trunc");
2235 CLI->mapIndVar([&](Instruction *) -> Value * {
2236 Builder.restoreIP(CLI->getBodyIP());
2237 return Builder.CreateAdd(IV, BackcastedDispatchCounter);
2238 });
2239
2240 // In the "exit" block, call the "fini" function.
2241 Builder.SetInsertPoint(DispatchExit, DispatchExit->getFirstInsertionPt());
2242 Builder.CreateCall(StaticFini, {SrcLoc, ThreadNum});
2243
2244 // Add the barrier if requested.
2245 if (NeedsBarrier)
2246 createBarrier(LocationDescription(Builder.saveIP(), DL), OMPD_for,
2247 /*ForceSimpleCall=*/false, /*CheckCancelFlag=*/false);
2248
2249 #ifndef NDEBUG
2250 // Even though we currently do not support applying additional methods to it,
2251 // the chunk loop should remain a canonical loop.
2252 CLI->assertOK();
2253 #endif
2254
2255 return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
2256 }
2257
2258 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
2259 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2260 bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
2261 llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
2262 bool HasNonmonotonicModifier, bool HasOrderedClause) {
2263 OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
2264 SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
2265 HasNonmonotonicModifier, HasOrderedClause);
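  // E.g. (illustrative): schedule(static) maps to BaseStatic,
  // schedule(static, N) to BaseStaticChunked, schedule(dynamic[, N]) to
  // BaseDynamicChunked, and an ordered clause adds ModifierOrdered, which
  // forces the dynamic lowering below.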
2266
2267 bool IsOrdered = (EffectiveScheduleType & OMPScheduleType::ModifierOrdered) ==
2268 OMPScheduleType::ModifierOrdered;
2269 switch (EffectiveScheduleType & ~OMPScheduleType::ModifierMask) {
2270 case OMPScheduleType::BaseStatic:
2271 assert(!ChunkSize && "No chunk size with static-chunked schedule");
2272 if (IsOrdered)
2273 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2274 NeedsBarrier, ChunkSize);
2275 // FIXME: Monotonicity ignored?
2276 return applyStaticWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier);
2277
2278 case OMPScheduleType::BaseStaticChunked:
2279 if (IsOrdered)
2280 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2281 NeedsBarrier, ChunkSize);
2282 // FIXME: Monotonicity ignored?
2283 return applyStaticChunkedWorkshareLoop(DL, CLI, AllocaIP, NeedsBarrier,
2284 ChunkSize);
2285
2286 case OMPScheduleType::BaseRuntime:
2287 case OMPScheduleType::BaseAuto:
2288 case OMPScheduleType::BaseGreedy:
2289 case OMPScheduleType::BaseBalanced:
2290 case OMPScheduleType::BaseSteal:
2291 case OMPScheduleType::BaseGuidedSimd:
2292 case OMPScheduleType::BaseRuntimeSimd:
2293 assert(!ChunkSize &&
2294 "schedule type does not support user-defined chunk sizes");
2295 LLVM_FALLTHROUGH;
2296 case OMPScheduleType::BaseDynamicChunked:
2297 case OMPScheduleType::BaseGuidedChunked:
2298 case OMPScheduleType::BaseGuidedIterativeChunked:
2299 case OMPScheduleType::BaseGuidedAnalyticalChunked:
2300 case OMPScheduleType::BaseStaticBalancedChunked:
2301 return applyDynamicWorkshareLoop(DL, CLI, AllocaIP, EffectiveScheduleType,
2302 NeedsBarrier, ChunkSize);
2303
2304 default:
2305 llvm_unreachable("Unknown/unimplemented schedule kind");
2306 }
2307 }
2308
2309 /// Returns an LLVM function to call for initializing loop bounds using OpenMP
2310 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
2311 /// the runtime. Always interpret integers as unsigned similarly to
2312 /// CanonicalLoopInfo.
2313 static FunctionCallee
2314 getKmpcForDynamicInitForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2315 unsigned Bitwidth = Ty->getIntegerBitWidth();
2316 if (Bitwidth == 32)
2317 return OMPBuilder.getOrCreateRuntimeFunction(
2318 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_4u);
2319 if (Bitwidth == 64)
2320 return OMPBuilder.getOrCreateRuntimeFunction(
2321 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_init_8u);
2322 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2323 }
2324
2325 /// Returns an LLVM function to call for updating the next loop using OpenMP
2326 /// dynamic scheduling depending on `type`. Only i32 and i64 are supported by
2327 /// the runtime. Always interpret integers as unsigned similarly to
2328 /// CanonicalLoopInfo.
2329 static FunctionCallee
2330 getKmpcForDynamicNextForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2331 unsigned Bitwidth = Ty->getIntegerBitWidth();
2332 if (Bitwidth == 32)
2333 return OMPBuilder.getOrCreateRuntimeFunction(
2334 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_4u);
2335 if (Bitwidth == 64)
2336 return OMPBuilder.getOrCreateRuntimeFunction(
2337 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_next_8u);
2338 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2339 }
2340
2341 /// Returns an LLVM function to call for finalizing the dynamic loop,
2342 /// depending on `type`. Only i32 and i64 are supported by the runtime. Always
2343 /// interpret integers as unsigned similarly to CanonicalLoopInfo.
2344 static FunctionCallee
2345 getKmpcForDynamicFiniForType(Type *Ty, Module &M, OpenMPIRBuilder &OMPBuilder) {
2346 unsigned Bitwidth = Ty->getIntegerBitWidth();
2347 if (Bitwidth == 32)
2348 return OMPBuilder.getOrCreateRuntimeFunction(
2349 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_4u);
2350 if (Bitwidth == 64)
2351 return OMPBuilder.getOrCreateRuntimeFunction(
2352 M, omp::RuntimeFunction::OMPRTL___kmpc_dispatch_fini_8u);
2353 llvm_unreachable("unknown OpenMP loop iterator bitwidth");
2354 }
2355
2356 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyDynamicWorkshareLoop(
2357 DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
2358 OMPScheduleType SchedType, bool NeedsBarrier, Value *Chunk) {
2359 assert(CLI->isValid() && "Requires a valid canonical loop");
2360 assert(!isConflictIP(AllocaIP, CLI->getPreheaderIP()) &&
2361 "Require dedicated allocate IP");
2362 assert(isValidWorkshareLoopScheduleType(SchedType) &&
2363 "Require valid schedule type");
2364
2365 bool Ordered = (SchedType & OMPScheduleType::ModifierOrdered) ==
2366 OMPScheduleType::ModifierOrdered;
2367
2368 // Set up the source location value for OpenMP runtime.
2369 Builder.SetCurrentDebugLocation(DL);
2370
2371 uint32_t SrcLocStrSize;
2372 Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
2373 Value *SrcLoc = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
2374
2375 // Declare useful OpenMP runtime functions.
2376 Value *IV = CLI->getIndVar();
2377 Type *IVTy = IV->getType();
2378 FunctionCallee DynamicInit = getKmpcForDynamicInitForType(IVTy, M, *this);
2379 FunctionCallee DynamicNext = getKmpcForDynamicNextForType(IVTy, M, *this);
2380
2381 // Allocate space for computed loop bounds as expected by the "init" function.
2382 Builder.restoreIP(AllocaIP);
2383 Type *I32Type = Type::getInt32Ty(M.getContext());
2384 Value *PLastIter = Builder.CreateAlloca(I32Type, nullptr, "p.lastiter");
2385 Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound");
2386 Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound");
2387 Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride");
2388
2389 // At the end of the preheader, prepare for calling the "init" function by
2390 // storing the current loop bounds into the allocated space. A canonical loop
2391 // always iterates from 0 to trip-count with step 1. Note that "init" expects
2392 // and produces an inclusive upper bound.
2393 BasicBlock *PreHeader = CLI->getPreheader();
2394 Builder.SetInsertPoint(PreHeader->getTerminator());
2395 Constant *One = ConstantInt::get(IVTy, 1);
2396 Builder.CreateStore(One, PLowerBound);
2397 Value *UpperBound = CLI->getTripCount();
2398 Builder.CreateStore(UpperBound, PUpperBound);
2399 Builder.CreateStore(One, PStride);
2400
2401 BasicBlock *Header = CLI->getHeader();
2402 BasicBlock *Exit = CLI->getExit();
2403 BasicBlock *Cond = CLI->getCond();
2404 BasicBlock *Latch = CLI->getLatch();
2405 InsertPointTy AfterIP = CLI->getAfterIP();
2406
2407 // The CLI will be "broken" in the code below, as the loop is no longer
2408 // a valid canonical loop.
2409
2410 if (!Chunk)
2411 Chunk = One;
2412
2413 Value *ThreadNum = getOrCreateThreadID(SrcLoc);
2414
2415 Constant *SchedulingType =
2416 ConstantInt::get(I32Type, static_cast<int>(SchedType));
2417
2418 // Call the "init" function.
2419 Builder.CreateCall(DynamicInit,
2420 {SrcLoc, ThreadNum, SchedulingType, /* LowerBound */ One,
2421 UpperBound, /* step */ One, Chunk});
2422
2423 // An outer loop around the existing one.
2424 BasicBlock *OuterCond = BasicBlock::Create(
2425 PreHeader->getContext(), Twine(PreHeader->getName()) + ".outer.cond",
2426 PreHeader->getParent());
2427   // The "next" call's result is always i32, so the zero compared against below must be 32-bit rather than IVTy-sized.
2428 Builder.SetInsertPoint(OuterCond, OuterCond->getFirstInsertionPt());
2429 Value *Res =
2430 Builder.CreateCall(DynamicNext, {SrcLoc, ThreadNum, PLastIter,
2431 PLowerBound, PUpperBound, PStride});
2432 Constant *Zero32 = ConstantInt::get(I32Type, 0);
2433 Value *MoreWork = Builder.CreateCmp(CmpInst::ICMP_NE, Res, Zero32);
2434 Value *LowerBound =
2435 Builder.CreateSub(Builder.CreateLoad(IVTy, PLowerBound), One, "lb");
2436 Builder.CreateCondBr(MoreWork, Header, Exit);
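  // After the rewiring below, the loop behaves roughly like this sketch
  // (C-like pseudocode; bounds from the runtime are 1-based and inclusive):
  //   while (__kmpc_dispatch_next_4u(loc, tid, &last, &lb, &ub, &stride)) {
  //     for (iv = lb - 1; iv < ub; ++iv)
  //       ...original loop body...
  //   }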
2437
2438 // Change PHI-node in loop header to use outer cond rather than preheader,
2439 // and set IV to the LowerBound.
2440 Instruction *Phi = &Header->front();
2441 auto *PI = cast<PHINode>(Phi);
2442 PI->setIncomingBlock(0, OuterCond);
2443 PI->setIncomingValue(0, LowerBound);
2444
2445 // Then set the pre-header to jump to the OuterCond
2446 Instruction *Term = PreHeader->getTerminator();
2447 auto *Br = cast<BranchInst>(Term);
2448 Br->setSuccessor(0, OuterCond);
2449
2450 // Modify the inner condition:
2451 // * Use the UpperBound returned from the DynamicNext call.
2452   // * Jump to the outer loop when done with one of the inner loops.
2453 Builder.SetInsertPoint(Cond, Cond->getFirstInsertionPt());
2454 UpperBound = Builder.CreateLoad(IVTy, PUpperBound, "ub");
2455 Instruction *Comp = &*Builder.GetInsertPoint();
2456 auto *CI = cast<CmpInst>(Comp);
2457 CI->setOperand(1, UpperBound);
2458 // Redirect the inner exit to branch to outer condition.
2459 Instruction *Branch = &Cond->back();
2460 auto *BI = cast<BranchInst>(Branch);
2461 assert(BI->getSuccessor(1) == Exit);
2462 BI->setSuccessor(1, OuterCond);
2463
2464 // Call the "fini" function if "ordered" is present in wsloop directive.
2465 if (Ordered) {
2466 Builder.SetInsertPoint(&Latch->back());
2467 FunctionCallee DynamicFini = getKmpcForDynamicFiniForType(IVTy, M, *this);
2468 Builder.CreateCall(DynamicFini, {SrcLoc, ThreadNum});
2469 }
2470
2471 // Add the barrier if requested.
2472 if (NeedsBarrier) {
2473 Builder.SetInsertPoint(&Exit->back());
2474 createBarrier(LocationDescription(Builder.saveIP(), DL),
2475 omp::Directive::OMPD_for, /* ForceSimpleCall */ false,
2476 /* CheckCancelFlag */ false);
2477 }
2478
2479 CLI->invalidate();
2480 return AfterIP;
2481 }
2482
2483 /// Redirect all edges that branch to \p OldTarget to \p NewTarget. That is,
2484 /// after this, \p OldTarget will be orphaned.
2485 static void redirectAllPredecessorsTo(BasicBlock *OldTarget,
2486 BasicBlock *NewTarget, DebugLoc DL) {
2487 for (BasicBlock *Pred : make_early_inc_range(predecessors(OldTarget)))
2488 redirectTo(Pred, NewTarget, DL);
2489 }
2490
2491 /// Determine which blocks in \p BBs are reachable from outside the set and
2492 /// remove from the function the ones that are not.
2493 static void removeUnusedBlocksFromParent(ArrayRef<BasicBlock *> BBs) {
2494 SmallPtrSet<BasicBlock *, 6> BBsToErase{BBs.begin(), BBs.end()};
2495 auto HasRemainingUses = [&BBsToErase](BasicBlock *BB) {
2496 for (Use &U : BB->uses()) {
2497 auto *UseInst = dyn_cast<Instruction>(U.getUser());
2498 if (!UseInst)
2499 continue;
2500 if (BBsToErase.count(UseInst->getParent()))
2501 continue;
2502 return true;
2503 }
2504 return false;
2505 };
2506
2507 while (true) {
2508 bool Changed = false;
2509 for (BasicBlock *BB : make_early_inc_range(BBsToErase)) {
2510 if (HasRemainingUses(BB)) {
2511 BBsToErase.erase(BB);
2512 Changed = true;
2513 }
2514 }
2515 if (!Changed)
2516 break;
2517 }
2518
2519 SmallVector<BasicBlock *, 7> BBVec(BBsToErase.begin(), BBsToErase.end());
2520 DeleteDeadBlocks(BBVec);
2521 }
2522
2523 CanonicalLoopInfo *
2524 OpenMPIRBuilder::collapseLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
2525 InsertPointTy ComputeIP) {
2526 assert(Loops.size() >= 1 && "At least one loop required");
2527 size_t NumLoops = Loops.size();
2528
2529 // Nothing to do if there is already just one loop.
2530 if (NumLoops == 1)
2531 return Loops.front();
2532
2533 CanonicalLoopInfo *Outermost = Loops.front();
2534 CanonicalLoopInfo *Innermost = Loops.back();
2535 BasicBlock *OrigPreheader = Outermost->getPreheader();
2536 BasicBlock *OrigAfter = Outermost->getAfter();
2537 Function *F = OrigPreheader->getParent();
2538
2539 // Loop control blocks that may become orphaned later.
2540 SmallVector<BasicBlock *, 12> OldControlBBs;
2541 OldControlBBs.reserve(6 * Loops.size());
2542 for (CanonicalLoopInfo *Loop : Loops)
2543 Loop->collectControlBlocks(OldControlBBs);
2544
2545 // Setup the IRBuilder for inserting the trip count computation.
2546 Builder.SetCurrentDebugLocation(DL);
2547 if (ComputeIP.isSet())
2548 Builder.restoreIP(ComputeIP);
2549 else
2550 Builder.restoreIP(Outermost->getPreheaderIP());
2551
2552   // Derive the collapsed loop's trip count.
2553 // TODO: Find common/largest indvar type.
2554 Value *CollapsedTripCount = nullptr;
2555 for (CanonicalLoopInfo *L : Loops) {
2556 assert(L->isValid() &&
2557 "All loops to collapse must be valid canonical loops");
2558 Value *OrigTripCount = L->getTripCount();
2559 if (!CollapsedTripCount) {
2560 CollapsedTripCount = OrigTripCount;
2561 continue;
2562 }
2563
2564 // TODO: Enable UndefinedSanitizer to diagnose an overflow here.
2565 CollapsedTripCount = Builder.CreateMul(CollapsedTripCount, OrigTripCount,
2566 {}, /*HasNUW=*/true);
2567 }
2568
2569 // Create the collapsed loop control flow.
2570 CanonicalLoopInfo *Result =
2571 createLoopSkeleton(DL, CollapsedTripCount, F,
2572 OrigPreheader->getNextNode(), OrigAfter, "collapsed");
2573
2574 // Build the collapsed loop body code.
2575 // Start with deriving the input loop induction variables from the collapsed
2576 // one, using a divmod scheme. To preserve the original loops' order, the
2577   // innermost loop uses the least significant bits.
2578 Builder.restoreIP(Result->getBodyIP());
2579
2580 Value *Leftover = Result->getIndVar();
2581 SmallVector<Value *> NewIndVars;
2582 NewIndVars.resize(NumLoops);
2583 for (int i = NumLoops - 1; i >= 1; --i) {
2584 Value *OrigTripCount = Loops[i]->getTripCount();
2585
2586 Value *NewIndVar = Builder.CreateURem(Leftover, OrigTripCount);
2587 NewIndVars[i] = NewIndVar;
2588
2589 Leftover = Builder.CreateUDiv(Leftover, OrigTripCount);
2590 }
2591 // Outermost loop gets all the remaining bits.
2592 NewIndVars[0] = Leftover;
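  // Worked example (illustrative): collapsing two loops with trip counts 3
  // and 4 gives a collapsed trip count of 12; for collapsed IV = 7, the
  // derived IVs are inner = 7 % 4 = 3 and outer = 7 / 4 = 1.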
2593
2594 // Construct the loop body control flow.
2595 // We progressively construct the branch structure following in direction of
2596 // the control flow, from the leading in-between code, the loop nest body, the
2597 // trailing in-between code, and rejoining the collapsed loop's latch.
2598   // ContinueBlock and ContinuePred keep track of the source(s) of the next edge. If
2599 // the ContinueBlock is set, continue with that block. If ContinuePred, use
2600 // its predecessors as sources.
2601 BasicBlock *ContinueBlock = Result->getBody();
2602 BasicBlock *ContinuePred = nullptr;
2603 auto ContinueWith = [&ContinueBlock, &ContinuePred, DL](BasicBlock *Dest,
2604 BasicBlock *NextSrc) {
2605 if (ContinueBlock)
2606 redirectTo(ContinueBlock, Dest, DL);
2607 else
2608 redirectAllPredecessorsTo(ContinuePred, Dest, DL);
2609
2610 ContinueBlock = nullptr;
2611 ContinuePred = NextSrc;
2612 };
2613
2614 // The code before the nested loop of each level.
2615 // Because we are sinking it into the nest, it will be executed more often
2616   // than in the original loop. More sophisticated schemes could keep track of what
2617 // the in-between code is and instantiate it only once per thread.
2618 for (size_t i = 0; i < NumLoops - 1; ++i)
2619 ContinueWith(Loops[i]->getBody(), Loops[i + 1]->getHeader());
2620
2621 // Connect the loop nest body.
2622 ContinueWith(Innermost->getBody(), Innermost->getLatch());
2623
2624 // The code after the nested loop at each level.
2625 for (size_t i = NumLoops - 1; i > 0; --i)
2626 ContinueWith(Loops[i]->getAfter(), Loops[i - 1]->getLatch());
2627
2628 // Connect the finished loop to the collapsed loop latch.
2629 ContinueWith(Result->getLatch(), nullptr);
2630
2631 // Replace the input loops with the new collapsed loop.
2632 redirectTo(Outermost->getPreheader(), Result->getPreheader(), DL);
2633 redirectTo(Result->getAfter(), Outermost->getAfter(), DL);
2634
2635 // Replace the input loop indvars with the derived ones.
2636 for (size_t i = 0; i < NumLoops; ++i)
2637 Loops[i]->getIndVar()->replaceAllUsesWith(NewIndVars[i]);
2638
2639 // Remove unused parts of the input loops.
2640 removeUnusedBlocksFromParent(OldControlBBs);
2641
2642 for (CanonicalLoopInfo *L : Loops)
2643 L->invalidate();
2644
2645 #ifndef NDEBUG
2646 Result->assertOK();
2647 #endif
2648 return Result;
2649 }
2650
2651 std::vector<CanonicalLoopInfo *>
2652 OpenMPIRBuilder::tileLoops(DebugLoc DL, ArrayRef<CanonicalLoopInfo *> Loops,
2653 ArrayRef<Value *> TileSizes) {
2654 assert(TileSizes.size() == Loops.size() &&
2655 "Must pass as many tile sizes as there are loops");
2656 int NumLoops = Loops.size();
2657 assert(NumLoops >= 1 && "At least one loop to tile required");
2658
2659 CanonicalLoopInfo *OutermostLoop = Loops.front();
2660 CanonicalLoopInfo *InnermostLoop = Loops.back();
2661 Function *F = OutermostLoop->getBody()->getParent();
2662 BasicBlock *InnerEnter = InnermostLoop->getBody();
2663 BasicBlock *InnerLatch = InnermostLoop->getLatch();
2664
2665 // Loop control blocks that may become orphaned later.
2666 SmallVector<BasicBlock *, 12> OldControlBBs;
2667 OldControlBBs.reserve(6 * Loops.size());
2668 for (CanonicalLoopInfo *Loop : Loops)
2669 Loop->collectControlBlocks(OldControlBBs);
2670
2671 // Collect original trip counts and induction variables to be accessible by
2672 // index. Also, the structure of the original loops is not preserved during
2673 // the construction of the tiled loops, so do it before we scavenge the BBs of
2674 // any original CanonicalLoopInfo.
2675 SmallVector<Value *, 4> OrigTripCounts, OrigIndVars;
2676 for (CanonicalLoopInfo *L : Loops) {
2677 assert(L->isValid() && "All input loops must be valid canonical loops");
2678 OrigTripCounts.push_back(L->getTripCount());
2679 OrigIndVars.push_back(L->getIndVar());
2680 }
2681
2682 // Collect the code between loop headers. These may contain SSA definitions
2683 // that are used in the loop nest body. To be usable within the innermost
2684 // body, these BasicBlocks will be sunk into the loop nest body. That is,
2685 // these instructions may be executed more often than before the tiling.
2686 // TODO: It would be sufficient to only sink them into body of the
2687 // corresponding tile loop.
2688 SmallVector<std::pair<BasicBlock *, BasicBlock *>, 4> InbetweenCode;
2689 for (int i = 0; i < NumLoops - 1; ++i) {
2690 CanonicalLoopInfo *Surrounding = Loops[i];
2691 CanonicalLoopInfo *Nested = Loops[i + 1];
2692
2693 BasicBlock *EnterBB = Surrounding->getBody();
2694 BasicBlock *ExitBB = Nested->getHeader();
2695 InbetweenCode.emplace_back(EnterBB, ExitBB);
2696 }
2697
2698 // Compute the trip counts of the floor loops.
2699 Builder.SetCurrentDebugLocation(DL);
2700 Builder.restoreIP(OutermostLoop->getPreheaderIP());
2701 SmallVector<Value *, 4> FloorCount, FloorRems;
2702 for (int i = 0; i < NumLoops; ++i) {
2703 Value *TileSize = TileSizes[i];
2704 Value *OrigTripCount = OrigTripCounts[i];
2705 Type *IVType = OrigTripCount->getType();
2706
2707 Value *FloorTripCount = Builder.CreateUDiv(OrigTripCount, TileSize);
2708 Value *FloorTripRem = Builder.CreateURem(OrigTripCount, TileSize);
2709
2710 // 0 if tripcount divides the tilesize, 1 otherwise.
2711 // 1 means we need an additional iteration for a partial tile.
2712 //
2713 // Unfortunately we cannot just use the roundup-formula
2714 // (tripcount + tilesize - 1)/tilesize
2715 // because the summation might overflow. We do not want to introduce undefined
2716 // behavior when the untiled loop nest did not.
2717 Value *FloorTripOverflow =
2718 Builder.CreateICmpNE(FloorTripRem, ConstantInt::get(IVType, 0));
2719
2720 FloorTripOverflow = Builder.CreateZExt(FloorTripOverflow, IVType);
2721 FloorTripCount =
2722 Builder.CreateAdd(FloorTripCount, FloorTripOverflow,
2723 "omp_floor" + Twine(i) + ".tripcount", true);
2724
2725 // Remember some values for later use.
2726 FloorCount.push_back(FloorTripCount);
2727 FloorRems.push_back(FloorTripRem);
2728 }
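// Worked example with illustrative values: OrigTripCount = 10 and TileSize = 4
// give FloorTripCount = 10/4 + (10%4 != 0) = 3 floor iterations and
// FloorTripRem = 2, the size of the final, partial tile.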
2729
2730 // Generate the new loop nest, from the outermost to the innermost.
2731 std::vector<CanonicalLoopInfo *> Result;
2732 Result.reserve(NumLoops * 2);
2733
2734 // The basic block of the surrounding loop that enters the newly generated
2735 // loop nest.
2736 BasicBlock *Enter = OutermostLoop->getPreheader();
2737
2738 // The basic block of the surrounding loop where the inner code should
2739 // continue.
2740 BasicBlock *Continue = OutermostLoop->getAfter();
2741
2742 // Where the next loop basic block should be inserted.
2743 BasicBlock *OutroInsertBefore = InnermostLoop->getExit();
2744
2745 auto EmbeddNewLoop =
2746 [this, DL, F, InnerEnter, &Enter, &Continue, &OutroInsertBefore](
2747 Value *TripCount, const Twine &Name) -> CanonicalLoopInfo * {
2748 CanonicalLoopInfo *EmbeddedLoop = createLoopSkeleton(
2749 DL, TripCount, F, InnerEnter, OutroInsertBefore, Name);
2750 redirectTo(Enter, EmbeddedLoop->getPreheader(), DL);
2751 redirectTo(EmbeddedLoop->getAfter(), Continue, DL);
2752
2753 // Set up the position where the next embedded loop connects to this loop.
2754 Enter = EmbeddedLoop->getBody();
2755 Continue = EmbeddedLoop->getLatch();
2756 OutroInsertBefore = EmbeddedLoop->getLatch();
2757 return EmbeddedLoop;
2758 };
2759
2760 auto EmbeddNewLoops = [&Result, &EmbeddNewLoop](ArrayRef<Value *> TripCounts,
2761 const Twine &NameBase) {
2762 for (auto P : enumerate(TripCounts)) {
2763 CanonicalLoopInfo *EmbeddedLoop =
2764 EmbeddNewLoop(P.value(), NameBase + Twine(P.index()));
2765 Result.push_back(EmbeddedLoop);
2766 }
2767 };
2768
2769 EmbeddNewLoops(FloorCount, "floor");
2770
2771 // Within the innermost floor loop, emit the code that computes the tile
2772 // sizes.
2773 Builder.SetInsertPoint(Enter->getTerminator());
2774 SmallVector<Value *, 4> TileCounts;
2775 for (int i = 0; i < NumLoops; ++i) {
2776 CanonicalLoopInfo *FloorLoop = Result[i];
2777 Value *TileSize = TileSizes[i];
2778
2779 Value *FloorIsEpilogue =
2780 Builder.CreateICmpEQ(FloorLoop->getIndVar(), FloorCount[i]);
2781 Value *TileTripCount =
2782 Builder.CreateSelect(FloorIsEpilogue, FloorRems[i], TileSize);
2783
2784 TileCounts.push_back(TileTripCount);
2785 }
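// Continuing the example above (illustrative values): the full floor
// iterations execute a tile of TileSize = 4 iterations, while a partial tile
// executes only FloorRems[i] = 2 iterations via the select emitted above.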
2786
2787 // Create the tile loops.
2788 EmbeddNewLoops(TileCounts, "tile");
2789
2790 // Insert the in-between code into the body.
2791 BasicBlock *BodyEnter = Enter;
2792 BasicBlock *BodyEntered = nullptr;
2793 for (std::pair<BasicBlock *, BasicBlock *> P : InbetweenCode) {
2794 BasicBlock *EnterBB = P.first;
2795 BasicBlock *ExitBB = P.second;
2796
2797 if (BodyEnter)
2798 redirectTo(BodyEnter, EnterBB, DL);
2799 else
2800 redirectAllPredecessorsTo(BodyEntered, EnterBB, DL);
2801
2802 BodyEnter = nullptr;
2803 BodyEntered = ExitBB;
2804 }
2805
2806 // Append the original loop nest body into the generated loop nest body.
2807 if (BodyEnter)
2808 redirectTo(BodyEnter, InnerEnter, DL);
2809 else
2810 redirectAllPredecessorsTo(BodyEntered, InnerEnter, DL);
2811 redirectAllPredecessorsTo(InnerLatch, Continue, DL);
2812
2813 // Replace the original induction variables with induction variables computed
2814 // from the tile and floor induction variables.
2815 Builder.restoreIP(Result.back()->getBodyIP());
2816 for (int i = 0; i < NumLoops; ++i) {
2817 CanonicalLoopInfo *FloorLoop = Result[i];
2818 CanonicalLoopInfo *TileLoop = Result[NumLoops + i];
2819 Value *OrigIndVar = OrigIndVars[i];
2820 Value *Size = TileSizes[i];
2821
2822 Value *Scale =
2823 Builder.CreateMul(Size, FloorLoop->getIndVar(), {}, /*HasNUW=*/true);
2824 Value *Shift =
2825 Builder.CreateAdd(Scale, TileLoop->getIndVar(), {}, /*HasNUW=*/true);
2826 OrigIndVar->replaceAllUsesWith(Shift);
2827 }
2828
2829 // Remove unused parts of the original loops.
2830 removeUnusedBlocksFromParent(OldControlBBs);
2831
2832 for (CanonicalLoopInfo *L : Loops)
2833 L->invalidate();
2834
2835 #ifndef NDEBUG
2836 for (CanonicalLoopInfo *GenL : Result)
2837 GenL->assertOK();
2838 #endif
2839 return Result;
2840 }
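// Usage sketch (hypothetical caller, assuming an i64 induction variable):
//   std::vector<CanonicalLoopInfo *> Nest =
//       OMPBuilder.tileLoops(DL, {CLI}, {Builder.getInt64(4)});
// Nest then contains all floor loops followed by all tile loops, i.e. Nest[0]
// is the outermost floor loop and Nest.back() the innermost tile loop.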
2841
2842 /// Attach loop metadata \p Properties to the loop described by \p Loop. If the
2843 /// loop already has metadata, the loop properties are appended.
2844 static void addLoopMetadata(CanonicalLoopInfo *Loop,
2845 ArrayRef<Metadata *> Properties) {
2846 assert(Loop->isValid() && "Expecting a valid CanonicalLoopInfo");
2847
2848 // Nothing to do if no property to attach.
2849 if (Properties.empty())
2850 return;
2851
2852 LLVMContext &Ctx = Loop->getFunction()->getContext();
2853 SmallVector<Metadata *> NewLoopProperties;
2854 NewLoopProperties.push_back(nullptr);
2855
2856 // If the loop already has metadata, prepend it to the new metadata.
2857 BasicBlock *Latch = Loop->getLatch();
2858 assert(Latch && "A valid CanonicalLoopInfo must have a unique latch");
2859 MDNode *Existing = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop);
2860 if (Existing)
2861 append_range(NewLoopProperties, drop_begin(Existing->operands(), 1));
2862
2863 append_range(NewLoopProperties, Properties);
2864 MDNode *LoopID = MDNode::getDistinct(Ctx, NewLoopProperties);
2865 LoopID->replaceOperandWith(0, LoopID);
2866
2867 Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID);
2868 }
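// For illustration (metadata numbering is arbitrary), attaching one property
// to a loop without existing metadata yields IR along the lines of
//   br i1 %cond, label %header, label %after, !llvm.loop !0
//   !0 = distinct !{!0, !1}
// where the first operand of !0 is the self-reference required for a distinct
// loop ID, as set up by replaceOperandWith above.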
2869
2870 /// Attach llvm.access.group metadata to the memref instructions of \p Block
2871 static void addSimdMetadata(BasicBlock *Block, MDNode *AccessGroup,
2872 LoopInfo &LI) {
2873 for (Instruction &I : *Block) {
2874 if (I.mayReadOrWriteMemory()) {
2875 // TODO: This instruction may already have access group from
2876 // other pragmas e.g. #pragma clang loop vectorize. Append
2877 // so that the existing metadata is not overwritten.
2878 I.setMetadata(LLVMContext::MD_access_group, AccessGroup);
2879 }
2880 }
2881 }
2882
2883 void OpenMPIRBuilder::unrollLoopFull(DebugLoc, CanonicalLoopInfo *Loop) {
2884 LLVMContext &Ctx = Builder.getContext();
2885 addLoopMetadata(
2886 Loop, {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
2887 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.full"))});
2888 }
2889
2890 void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {
2891 LLVMContext &Ctx = Builder.getContext();
2892 addLoopMetadata(
2893 Loop, {
2894 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
2895 });
2896 }
2897
2898 void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
2899 ConstantInt *Simdlen) {
2900 LLVMContext &Ctx = Builder.getContext();
2901
2902 Function *F = CanonicalLoop->getFunction();
2903
2904 FunctionAnalysisManager FAM;
2905 FAM.registerPass([]() { return DominatorTreeAnalysis(); });
2906 FAM.registerPass([]() { return LoopAnalysis(); });
2907 FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
2908
2909 LoopAnalysis LIA;
2910 LoopInfo &&LI = LIA.run(*F, FAM);
2911
2912 Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
2913
2914 SmallSet<BasicBlock *, 8> Reachable;
2915
2916 // Get the basic blocks from the loop in which memref instructions
2917 // can be found.
2918 // TODO: Generalize getting all blocks inside a CanonicalLoopInfo,
2919 // preferably without running any passes.
2920 for (BasicBlock *Block : L->getBlocks()) {
2921 if (Block == CanonicalLoop->getCond() ||
2922 Block == CanonicalLoop->getHeader())
2923 continue;
2924 Reachable.insert(Block);
2925 }
2926
2927 // Add access group metadata to memory-access instructions.
2928 MDNode *AccessGroup = MDNode::getDistinct(Ctx, {});
2929 for (BasicBlock *BB : Reachable)
2930 addSimdMetadata(BB, AccessGroup, LI);
2931
2932 // Use the above access group metadata to create loop level
2933 // metadata, which should be distinct for each loop.
2934 ConstantAsMetadata *BoolConst =
2935 ConstantAsMetadata::get(ConstantInt::getTrue(Type::getInt1Ty(Ctx)));
2936 // TODO: If the loop has existing parallel access metadata, have
2937 // to combine two lists.
2938 addLoopMetadata(
2939 CanonicalLoop,
2940 {MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"),
2941 AccessGroup}),
2942 MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
2943 BoolConst})});
2944 if (Simdlen != nullptr)
2945 addLoopMetadata(
2946 CanonicalLoop,
2947 MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.width"),
2948 ConstantAsMetadata::get(Simdlen)}));
2949 }
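// As a sketch with arbitrary metadata numbering, applySimd with a Simdlen of 8
// attaches loop metadata along the lines of
//   !0 = distinct !{!0, !1, !2, !3}
//   !1 = !{!"llvm.loop.parallel_accesses", !4}
//   !2 = !{!"llvm.loop.vectorize.enable", i1 true}
//   !3 = !{!"llvm.loop.vectorize.width", i32 8}
// where !4 is the distinct access group attached to the memory instructions.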
2950
2951 /// Create the TargetMachine object to query the backend for optimization
2952 /// preferences.
2953 ///
2954 /// Ideally, this would be passed from the front-end to the OpenMPBuilder, but
2955 /// e.g. Clang does not pass it to its CodeGen layer and creates it only when
2956 /// needed for the LLVM pass pipeline. We use some default options to avoid
2957 /// having to pass too many settings from the frontend that probably do not
2958 /// matter.
2959 ///
2960 /// Currently, TargetMachine is only used sometimes by the unrollLoopPartial
2961 /// method. If we are going to use TargetMachine for more purposes, especially
2962 /// those that are sensitive to TargetOptions, RelocModel and CodeModel, it
2963 /// might become worth requiring front-ends to pass on their TargetMachine,
2964 /// or at least cache it between methods. Note that while frontends such as
2965 /// have just a single main TargetMachine per translation unit, "target-cpu" and
2966 /// "target-features" that determine the TargetMachine are per-function and can
2967 /// be overridden using __attribute__((target("OPTIONS"))).
2968 static std::unique_ptr<TargetMachine>
2969 createTargetMachine(Function *F, CodeGenOpt::Level OptLevel) {
2970 Module *M = F->getParent();
2971
2972 StringRef CPU = F->getFnAttribute("target-cpu").getValueAsString();
2973 StringRef Features = F->getFnAttribute("target-features").getValueAsString();
2974 const std::string &Triple = M->getTargetTriple();
2975
2976 std::string Error;
2977 const llvm::Target *TheTarget = TargetRegistry::lookupTarget(Triple, Error);
2978 if (!TheTarget)
2979 return {};
2980
2981 llvm::TargetOptions Options;
2982 return std::unique_ptr<TargetMachine>(TheTarget->createTargetMachine(
2983 Triple, CPU, Features, Options, /*RelocModel=*/None, /*CodeModel=*/None,
2984 OptLevel));
2985 }
2986
2987 /// Heuristically determine the best-performant unroll factor for \p CLI. This
2988 /// depends on the target processor. We are re-using the same heuristics as the
2989 /// LoopUnrollPass.
2990 static int32_t computeHeuristicUnrollFactor(CanonicalLoopInfo *CLI) {
2991 Function *F = CLI->getFunction();
2992
2993 // Assume the user requests the most aggressive unrolling, even if the rest of
2994 // the code is optimized using a lower setting.
2995 CodeGenOpt::Level OptLevel = CodeGenOpt::Aggressive;
2996 std::unique_ptr<TargetMachine> TM = createTargetMachine(F, OptLevel);
2997
2998 FunctionAnalysisManager FAM;
2999 FAM.registerPass([]() { return TargetLibraryAnalysis(); });
3000 FAM.registerPass([]() { return AssumptionAnalysis(); });
3001 FAM.registerPass([]() { return DominatorTreeAnalysis(); });
3002 FAM.registerPass([]() { return LoopAnalysis(); });
3003 FAM.registerPass([]() { return ScalarEvolutionAnalysis(); });
3004 FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
3005 TargetIRAnalysis TIRA;
3006 if (TM)
3007 TIRA = TargetIRAnalysis(
3008 [&](const Function &F) { return TM->getTargetTransformInfo(F); });
3009 FAM.registerPass([&]() { return TIRA; });
3010
3011 TargetIRAnalysis::Result &&TTI = TIRA.run(*F, FAM);
3012 ScalarEvolutionAnalysis SEA;
3013 ScalarEvolution &&SE = SEA.run(*F, FAM);
3014 DominatorTreeAnalysis DTA;
3015 DominatorTree &&DT = DTA.run(*F, FAM);
3016 LoopAnalysis LIA;
3017 LoopInfo &&LI = LIA.run(*F, FAM);
3018 AssumptionAnalysis ACT;
3019 AssumptionCache &&AC = ACT.run(*F, FAM);
3020 OptimizationRemarkEmitter ORE{F};
3021
3022 Loop *L = LI.getLoopFor(CLI->getHeader());
3023 assert(L && "Expecting CanonicalLoopInfo to be recognized as a loop");
3024
3025 TargetTransformInfo::UnrollingPreferences UP =
3026 gatherUnrollingPreferences(L, SE, TTI,
3027 /*BlockFrequencyInfo=*/nullptr,
3028 /*ProfileSummaryInfo=*/nullptr, ORE, OptLevel,
3029 /*UserThreshold=*/None,
3030 /*UserCount=*/None,
3031 /*UserAllowPartial=*/true,
3032 /*UserAllowRuntime=*/true,
3033 /*UserUpperBound=*/None,
3034 /*UserFullUnrollMaxCount=*/None);
3035
3036 UP.Force = true;
3037
3038 // Account for additional optimizations taking place before the LoopUnrollPass
3039 // would unroll the loop.
3040 UP.Threshold *= UnrollThresholdFactor;
3041 UP.PartialThreshold *= UnrollThresholdFactor;
3042
3043 // Use normal unroll factors even if the rest of the code is optimized for
3044 // size.
3045 UP.OptSizeThreshold = UP.Threshold;
3046 UP.PartialOptSizeThreshold = UP.PartialThreshold;
3047
3048 LLVM_DEBUG(dbgs() << "Unroll heuristic thresholds:\n"
3049 << " Threshold=" << UP.Threshold << "\n"
3050 << " PartialThreshold=" << UP.PartialThreshold << "\n"
3051 << " OptSizeThreshold=" << UP.OptSizeThreshold << "\n"
3052 << " PartialOptSizeThreshold="
3053 << UP.PartialOptSizeThreshold << "\n");
3054
3055 // Disable peeling.
3056 TargetTransformInfo::PeelingPreferences PP =
3057 gatherPeelingPreferences(L, SE, TTI,
3058 /*UserAllowPeeling=*/false,
3059 /*UserAllowProfileBasedPeeling=*/false,
3060 /*UnrollingSpecficValues=*/false);
3061
3062 SmallPtrSet<const Value *, 32> EphValues;
3063 CodeMetrics::collectEphemeralValues(L, &AC, EphValues);
3064
3065 // Assume that reads and writes to stack variables can be eliminated by
3066 // Mem2Reg, SROA or LICM. That is, don't count them towards the loop body's
3067 // size.
3068 for (BasicBlock *BB : L->blocks()) {
3069 for (Instruction &I : *BB) {
3070 Value *Ptr;
3071 if (auto *Load = dyn_cast<LoadInst>(&I)) {
3072 Ptr = Load->getPointerOperand();
3073 } else if (auto *Store = dyn_cast<StoreInst>(&I)) {
3074 Ptr = Store->getPointerOperand();
3075 } else
3076 continue;
3077
3078 Ptr = Ptr->stripPointerCasts();
3079
3080 if (auto *Alloca = dyn_cast<AllocaInst>(Ptr)) {
3081 if (Alloca->getParent() == &F->getEntryBlock())
3082 EphValues.insert(&I);
3083 }
3084 }
3085 }
3086
3087 unsigned NumInlineCandidates;
3088 bool NotDuplicatable;
3089 bool Convergent;
3090 InstructionCost LoopSizeIC =
3091 ApproximateLoopSize(L, NumInlineCandidates, NotDuplicatable, Convergent,
3092 TTI, EphValues, UP.BEInsns);
3093 LLVM_DEBUG(dbgs() << "Estimated loop size is " << LoopSizeIC << "\n");
3094
3095 // Loop is not unrollable if the loop contains certain instructions.
3096 if (NotDuplicatable || Convergent || !LoopSizeIC.isValid()) {
3097 LLVM_DEBUG(dbgs() << "Loop not considered unrollable\n");
3098 return 1;
3099 }
3100 unsigned LoopSize = *LoopSizeIC.getValue();
3101
3102 // TODO: Determine trip count of \p CLI if constant, computeUnrollCount might
3103 // be able to use it.
3104 int TripCount = 0;
3105 int MaxTripCount = 0;
3106 bool MaxOrZero = false;
3107 unsigned TripMultiple = 0;
3108
3109 bool UseUpperBound = false;
3110 computeUnrollCount(L, TTI, DT, &LI, SE, EphValues, &ORE, TripCount,
3111 MaxTripCount, MaxOrZero, TripMultiple, LoopSize, UP, PP,
3112 UseUpperBound);
3113 unsigned Factor = UP.Count;
3114 LLVM_DEBUG(dbgs() << "Suggesting unroll factor of " << Factor << "\n");
3115
3116 // This function returns 1 to signal that the loop should not be unrolled.
3117 if (Factor == 0)
3118 return 1;
3119 return Factor;
3120 }
3121
3122 void OpenMPIRBuilder::unrollLoopPartial(DebugLoc DL, CanonicalLoopInfo *Loop,
3123 int32_t Factor,
3124 CanonicalLoopInfo **UnrolledCLI) {
3125 assert(Factor >= 0 && "Unroll factor must not be negative");
3126
3127 Function *F = Loop->getFunction();
3128 LLVMContext &Ctx = F->getContext();
3129
3130 // If the unrolled loop is not used for another loop-associated directive, it
3131 // is sufficient to add metadata for the LoopUnrollPass.
3132 if (!UnrolledCLI) {
3133 SmallVector<Metadata *, 2> LoopMetadata;
3134 LoopMetadata.push_back(
3135 MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")));
3136
3137 if (Factor >= 1) {
3138 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
3139 ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
3140 LoopMetadata.push_back(MDNode::get(
3141 Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst}));
3142 }
3143
3144 addLoopMetadata(Loop, LoopMetadata);
3145 return;
3146 }
3147
3148 // Heuristically determine the unroll factor.
3149 if (Factor == 0)
3150 Factor = computeHeuristicUnrollFactor(Loop);
3151
3152 // No change required with unroll factor 1.
3153 if (Factor == 1) {
3154 *UnrolledCLI = Loop;
3155 return;
3156 }
3157
3158 assert(Factor >= 2 &&
3159 "unrolling only makes sense with a factor of 2 or larger");
3160
3161 Type *IndVarTy = Loop->getIndVarType();
3162
3163 // Apply partial unrolling by tiling the loop by the unroll-factor, then fully
3164 // unroll the inner loop.
3165 Value *FactorVal =
3166 ConstantInt::get(IndVarTy, APInt(IndVarTy->getIntegerBitWidth(), Factor,
3167 /*isSigned=*/false));
3168 std::vector<CanonicalLoopInfo *> LoopNest =
3169 tileLoops(DL, {Loop}, {FactorVal});
3170 assert(LoopNest.size() == 2 && "Expect 2 loops after tiling");
3171 *UnrolledCLI = LoopNest[0];
3172 CanonicalLoopInfo *InnerLoop = LoopNest[1];
3173
3174 // LoopUnrollPass can only fully unroll loops with constant trip count.
3175 // Unroll by the unroll factor with a fallback epilog for the remainder
3176 // iterations if necessary.
3177 ConstantAsMetadata *FactorConst = ConstantAsMetadata::get(
3178 ConstantInt::get(Type::getInt32Ty(Ctx), APInt(32, Factor)));
3179 addLoopMetadata(
3180 InnerLoop,
3181 {MDNode::get(Ctx, MDString::get(Ctx, "llvm.loop.unroll.enable")),
3182 MDNode::get(
3183 Ctx, {MDString::get(Ctx, "llvm.loop.unroll.count"), FactorConst})});
3184
3185 #ifndef NDEBUG
3186 (*UnrolledCLI)->assertOK();
3187 #endif
3188 }
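// Illustrative result: partially unrolling a loop of trip count %n by a factor
// of 4 produces a floor loop of ceil(%n / 4) iterations around a tile loop of
// up to 4 iterations; only the inner tile loop carries the
// llvm.loop.unroll.count metadata, so the LoopUnrollPass unrolls just that
// loop and the floor loop remains as the epilog-handling outer loop.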
3189
3190 OpenMPIRBuilder::InsertPointTy
3191 OpenMPIRBuilder::createCopyPrivate(const LocationDescription &Loc,
3192 llvm::Value *BufSize, llvm::Value *CpyBuf,
3193 llvm::Value *CpyFn, llvm::Value *DidIt) {
3194 if (!updateToLocation(Loc))
3195 return Loc.IP;
3196
3197 uint32_t SrcLocStrSize;
3198 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3199 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3200 Value *ThreadId = getOrCreateThreadID(Ident);
3201
3202 llvm::Value *DidItLD = Builder.CreateLoad(Builder.getInt32Ty(), DidIt);
3203
3204 Value *Args[] = {Ident, ThreadId, BufSize, CpyBuf, CpyFn, DidItLD};
3205
3206 Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_copyprivate);
3207 Builder.CreateCall(Fn, Args);
3208
3209 return Builder.saveIP();
3210 }
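// For orientation, the emitted call targets the runtime entry point declared
// in the OpenMP runtime (kmp_csupport.cpp) roughly as
//   void __kmpc_copyprivate(ident_t *loc, kmp_int32 gtid, size_t cpy_size,
//                           void *cpy_data, void (*cpy_func)(void *, void *),
//                           kmp_int32 didit);
// see the runtime sources for the authoritative declaration.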
3211
3212 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createSingle(
3213 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3214 FinalizeCallbackTy FiniCB, bool IsNowait, llvm::Value *DidIt) {
3215
3216 if (!updateToLocation(Loc))
3217 return Loc.IP;
3218
3219 // If needed (i.e. not null), initialize `DidIt` with 0
3220 if (DidIt) {
3221 Builder.CreateStore(Builder.getInt32(0), DidIt);
3222 }
3223
3224 Directive OMPD = Directive::OMPD_single;
3225 uint32_t SrcLocStrSize;
3226 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3227 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3228 Value *ThreadId = getOrCreateThreadID(Ident);
3229 Value *Args[] = {Ident, ThreadId};
3230
3231 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_single);
3232 Instruction *EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3233
3234 Function *ExitRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_single);
3235 Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3236
3237 // generates the following:
3238 // if (__kmpc_single()) {
3239 // .... single region ...
3240 // __kmpc_end_single
3241 // }
3242 // __kmpc_barrier
3243
3244 EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3245 /*Conditional*/ true,
3246 /*hasFinalize*/ true);
3247 if (!IsNowait)
3248 createBarrier(LocationDescription(Builder.saveIP(), Loc.DL),
3249 omp::Directive::OMPD_unknown, /* ForceSimpleCall */ false,
3250 /* CheckCancelFlag */ false);
3251 return Builder.saveIP();
3252 }
3253
3254 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCritical(
3255 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3256 FinalizeCallbackTy FiniCB, StringRef CriticalName, Value *HintInst) {
3257
3258 if (!updateToLocation(Loc))
3259 return Loc.IP;
3260
3261 Directive OMPD = Directive::OMPD_critical;
3262 uint32_t SrcLocStrSize;
3263 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3264 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3265 Value *ThreadId = getOrCreateThreadID(Ident);
3266 Value *LockVar = getOMPCriticalRegionLock(CriticalName);
3267 Value *Args[] = {Ident, ThreadId, LockVar};
3268
3269 SmallVector<llvm::Value *, 4> EnterArgs(std::begin(Args), std::end(Args));
3270 Function *RTFn = nullptr;
3271 if (HintInst) {
3272 // Add Hint to entry Args and create call
3273 EnterArgs.push_back(HintInst);
3274 RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical_with_hint);
3275 } else {
3276 RTFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_critical);
3277 }
3278 Instruction *EntryCall = Builder.CreateCall(RTFn, EnterArgs);
3279
3280 Function *ExitRTLFn =
3281 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_critical);
3282 Instruction *ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3283
3284 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3285 /*Conditional*/ false, /*hasFinalize*/ true);
3286 }
3287
3288 OpenMPIRBuilder::InsertPointTy
3289 OpenMPIRBuilder::createOrderedDepend(const LocationDescription &Loc,
3290 InsertPointTy AllocaIP, unsigned NumLoops,
3291 ArrayRef<llvm::Value *> StoreValues,
3292 const Twine &Name, bool IsDependSource) {
3293 for (size_t I = 0; I < StoreValues.size(); I++)
3294 assert(StoreValues[I]->getType()->isIntegerTy(64) &&
3295 "OpenMP runtime requires depend vec with i64 type");
3296
3297 if (!updateToLocation(Loc))
3298 return Loc.IP;
3299
3300 // Allocate space for vector and generate alloc instruction.
3301 auto *ArrI64Ty = ArrayType::get(Int64, NumLoops);
3302 Builder.restoreIP(AllocaIP);
3303 AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI64Ty, nullptr, Name);
3304 ArgsBase->setAlignment(Align(8));
3305 Builder.restoreIP(Loc.IP);
3306
3307 // Store the index value with offset in depend vector.
3308 for (unsigned I = 0; I < NumLoops; ++I) {
3309 Value *DependAddrGEPIter = Builder.CreateInBoundsGEP(
3310 ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(I)});
3311 StoreInst *STInst = Builder.CreateStore(StoreValues[I], DependAddrGEPIter);
3312 STInst->setAlignment(Align(8));
3313 }
3314
3315 Value *DependBaseAddrGEP = Builder.CreateInBoundsGEP(
3316 ArrI64Ty, ArgsBase, {Builder.getInt64(0), Builder.getInt64(0)});
3317
3318 uint32_t SrcLocStrSize;
3319 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3320 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3321 Value *ThreadId = getOrCreateThreadID(Ident);
3322 Value *Args[] = {Ident, ThreadId, DependBaseAddrGEP};
3323
3324 Function *RTLFn = nullptr;
3325 if (IsDependSource)
3326 RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_post);
3327 else
3328 RTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_doacross_wait);
3329 Builder.CreateCall(RTLFn, Args);
3330
3331 return Builder.saveIP();
3332 }
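// Sketch of the emitted IR for NumLoops == 2 (value names are illustrative,
// element accesses abbreviated; the real code uses the in-bounds GEPs above):
//   %vec = alloca [2 x i64], align 8
//   store i64 %iv0, ptr %vec[0]
//   store i64 %iv1, ptr %vec[1]
//   call void @__kmpc_doacross_post(ptr %ident, i32 %tid, ptr %vec[0])
// with __kmpc_doacross_wait called instead when IsDependSource is false.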
3333
3334 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createOrderedThreadsSimd(
3335 const LocationDescription &Loc, BodyGenCallbackTy BodyGenCB,
3336 FinalizeCallbackTy FiniCB, bool IsThreads) {
3337 if (!updateToLocation(Loc))
3338 return Loc.IP;
3339
3340 Directive OMPD = Directive::OMPD_ordered;
3341 Instruction *EntryCall = nullptr;
3342 Instruction *ExitCall = nullptr;
3343
3344 if (IsThreads) {
3345 uint32_t SrcLocStrSize;
3346 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3347 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3348 Value *ThreadId = getOrCreateThreadID(Ident);
3349 Value *Args[] = {Ident, ThreadId};
3350
3351 Function *EntryRTLFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_ordered);
3352 EntryCall = Builder.CreateCall(EntryRTLFn, Args);
3353
3354 Function *ExitRTLFn =
3355 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_ordered);
3356 ExitCall = Builder.CreateCall(ExitRTLFn, Args);
3357 }
3358
3359 return EmitOMPInlinedRegion(OMPD, EntryCall, ExitCall, BodyGenCB, FiniCB,
3360 /*Conditional*/ false, /*hasFinalize*/ true);
3361 }
3362
3363 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::EmitOMPInlinedRegion(
3364 Directive OMPD, Instruction *EntryCall, Instruction *ExitCall,
3365 BodyGenCallbackTy BodyGenCB, FinalizeCallbackTy FiniCB, bool Conditional,
3366 bool HasFinalize, bool IsCancellable) {
3367
3368 if (HasFinalize)
3369 FinalizationStack.push_back({FiniCB, OMPD, IsCancellable});
3370
3371 // Create inlined region's entry and body blocks, in preparation
3372 // for conditional creation
3373 BasicBlock *EntryBB = Builder.GetInsertBlock();
3374 Instruction *SplitPos = EntryBB->getTerminator();
3375 if (!isa_and_nonnull<BranchInst>(SplitPos))
3376 SplitPos = new UnreachableInst(Builder.getContext(), EntryBB);
3377 BasicBlock *ExitBB = EntryBB->splitBasicBlock(SplitPos, "omp_region.end");
3378 BasicBlock *FiniBB =
3379 EntryBB->splitBasicBlock(EntryBB->getTerminator(), "omp_region.finalize");
3380
3381 Builder.SetInsertPoint(EntryBB->getTerminator());
3382 emitCommonDirectiveEntry(OMPD, EntryCall, ExitBB, Conditional);
3383
3384 // generate body
3385 BodyGenCB(/* AllocaIP */ InsertPointTy(),
3386 /* CodeGenIP */ Builder.saveIP());
3387
3388 // emit exit call and do any needed finalization.
3389 auto FinIP = InsertPointTy(FiniBB, FiniBB->getFirstInsertionPt());
3390 assert(FiniBB->getTerminator()->getNumSuccessors() == 1 &&
3391 FiniBB->getTerminator()->getSuccessor(0) == ExitBB &&
3392 "Unexpected control flow graph state!!");
3393 emitCommonDirectiveExit(OMPD, FinIP, ExitCall, HasFinalize);
3394 assert(FiniBB->getUniquePredecessor()->getUniqueSuccessor() == FiniBB &&
3395 "Unexpected Control Flow State!");
3396 MergeBlockIntoPredecessor(FiniBB);
3397
3398 // If we are skipping the region of a non-conditional, remove the exit
3399 // block, and clear the builder's insertion point.
3400 assert(SplitPos->getParent() == ExitBB &&
3401 "Unexpected Insertion point location!");
3402 auto merged = MergeBlockIntoPredecessor(ExitBB);
3403 BasicBlock *ExitPredBB = SplitPos->getParent();
3404 auto InsertBB = merged ? ExitPredBB : ExitBB;
3405 if (!isa_and_nonnull<BranchInst>(SplitPos))
3406 SplitPos->eraseFromParent();
3407 Builder.SetInsertPoint(InsertBB);
3408
3409 return Builder.saveIP();
3410 }
3411
3412 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveEntry(
3413 Directive OMPD, Value *EntryCall, BasicBlock *ExitBB, bool Conditional) {
3414 // If there is nothing to do, return the current insertion point.
3415 if (!Conditional || !EntryCall)
3416 return Builder.saveIP();
3417
3418 BasicBlock *EntryBB = Builder.GetInsertBlock();
3419 Value *CallBool = Builder.CreateIsNotNull(EntryCall);
3420 auto *ThenBB = BasicBlock::Create(M.getContext(), "omp_region.body");
3421 auto *UI = new UnreachableInst(Builder.getContext(), ThenBB);
3422
3423 // Emit thenBB and set the Builder's insertion point there for
3424 // body generation next. Place the block after the current block.
3425 Function *CurFn = EntryBB->getParent();
3426 CurFn->getBasicBlockList().insertAfter(EntryBB->getIterator(), ThenBB);
3427
3428 // Move Entry branch to end of ThenBB, and replace with conditional
3429 // branch (If-stmt)
3430 Instruction *EntryBBTI = EntryBB->getTerminator();
3431 Builder.CreateCondBr(CallBool, ThenBB, ExitBB);
3432 EntryBBTI->removeFromParent();
3433 Builder.SetInsertPoint(UI);
3434 Builder.Insert(EntryBBTI);
3435 UI->eraseFromParent();
3436 Builder.SetInsertPoint(ThenBB->getTerminator());
3437
3438 // return an insertion point to ExitBB.
3439 return IRBuilder<>::InsertPoint(ExitBB, ExitBB->getFirstInsertionPt());
3440 }
3441
3442 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitCommonDirectiveExit(
3443 omp::Directive OMPD, InsertPointTy FinIP, Instruction *ExitCall,
3444 bool HasFinalize) {
3445
3446 Builder.restoreIP(FinIP);
3447
3448 // If there is finalization to do, emit it before the exit call
3449 if (HasFinalize) {
3450 assert(!FinalizationStack.empty() &&
3451 "Unexpected finalization stack state!");
3452
3453 FinalizationInfo Fi = FinalizationStack.pop_back_val();
3454 assert(Fi.DK == OMPD && "Unexpected Directive for Finalization call!");
3455
3456 Fi.FiniCB(FinIP);
3457
3458 BasicBlock *FiniBB = FinIP.getBlock();
3459 Instruction *FiniBBTI = FiniBB->getTerminator();
3460
3461 // set Builder IP for call creation
3462 Builder.SetInsertPoint(FiniBBTI);
3463 }
3464
3465 if (!ExitCall)
3466 return Builder.saveIP();
3467
3468 // Place the exit call as the last instruction before the finalization block
3469 ExitCall->removeFromParent();
3470 Builder.Insert(ExitCall);
3471
3472 return IRBuilder<>::InsertPoint(ExitCall->getParent(),
3473 ExitCall->getIterator());
3474 }
3475
3476 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createCopyinClauseBlocks(
3477 InsertPointTy IP, Value *MasterAddr, Value *PrivateAddr,
3478 llvm::IntegerType *IntPtrTy, bool BranchtoEnd) {
3479 if (!IP.isSet())
3480 return IP;
3481
3482 IRBuilder<>::InsertPointGuard IPG(Builder);
3483
3484 // creates the following CFG structure
3485 // OMP_Entry : (MasterAddr != PrivateAddr)?
3486 // F T
3487 // | \
3488 // | copyin.not.master
3489 // | /
3490 // v /
3491 // copyin.not.master.end
3492 // |
3493 // v
3494 // OMP.Entry.Next
3495
3496 BasicBlock *OMP_Entry = IP.getBlock();
3497 Function *CurFn = OMP_Entry->getParent();
3498 BasicBlock *CopyBegin =
3499 BasicBlock::Create(M.getContext(), "copyin.not.master", CurFn);
3500 BasicBlock *CopyEnd = nullptr;
3501
3502 // If the entry block is terminated, split it to preserve the branch to the
3503 // following basic block (i.e. OMP.Entry.Next); otherwise, leave everything as is.
3504 if (isa_and_nonnull<BranchInst>(OMP_Entry->getTerminator())) {
3505 CopyEnd = OMP_Entry->splitBasicBlock(OMP_Entry->getTerminator(),
3506 "copyin.not.master.end");
3507 OMP_Entry->getTerminator()->eraseFromParent();
3508 } else {
3509 CopyEnd =
3510 BasicBlock::Create(M.getContext(), "copyin.not.master.end", CurFn);
3511 }
3512
3513 Builder.SetInsertPoint(OMP_Entry);
3514 Value *MasterPtr = Builder.CreatePtrToInt(MasterAddr, IntPtrTy);
3515 Value *PrivatePtr = Builder.CreatePtrToInt(PrivateAddr, IntPtrTy);
3516 Value *cmp = Builder.CreateICmpNE(MasterPtr, PrivatePtr);
3517 Builder.CreateCondBr(cmp, CopyBegin, CopyEnd);
3518
3519 Builder.SetInsertPoint(CopyBegin);
3520 if (BranchtoEnd)
3521 Builder.SetInsertPoint(Builder.CreateBr(CopyEnd));
3522
3523 return Builder.saveIP();
3524 }
3525
3526 CallInst *OpenMPIRBuilder::createOMPAlloc(const LocationDescription &Loc,
3527 Value *Size, Value *Allocator,
3528 std::string Name) {
3529 IRBuilder<>::InsertPointGuard IPG(Builder);
3530 Builder.restoreIP(Loc.IP);
3531
3532 uint32_t SrcLocStrSize;
3533 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3534 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3535 Value *ThreadId = getOrCreateThreadID(Ident);
3536 Value *Args[] = {ThreadId, Size, Allocator};
3537
3538 Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_alloc);
3539
3540 return Builder.CreateCall(Fn, Args, Name);
3541 }
3542
3543 CallInst *OpenMPIRBuilder::createOMPFree(const LocationDescription &Loc,
3544 Value *Addr, Value *Allocator,
3545 std::string Name) {
3546 IRBuilder<>::InsertPointGuard IPG(Builder);
3547 Builder.restoreIP(Loc.IP);
3548
3549 uint32_t SrcLocStrSize;
3550 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3551 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3552 Value *ThreadId = getOrCreateThreadID(Ident);
3553 Value *Args[] = {ThreadId, Addr, Allocator};
3554 Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_free);
3555 return Builder.CreateCall(Fn, Args, Name);
3556 }
3557
3558 CallInst *OpenMPIRBuilder::createOMPInteropInit(
3559 const LocationDescription &Loc, Value *InteropVar,
3560 omp::OMPInteropType InteropType, Value *Device, Value *NumDependences,
3561 Value *DependenceAddress, bool HaveNowaitClause) {
3562 IRBuilder<>::InsertPointGuard IPG(Builder);
3563 Builder.restoreIP(Loc.IP);
3564
3565 uint32_t SrcLocStrSize;
3566 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3567 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3568 Value *ThreadId = getOrCreateThreadID(Ident);
3569 if (Device == nullptr)
3570 Device = ConstantInt::get(Int32, -1);
3571 Constant *InteropTypeVal = ConstantInt::get(Int64, (int)InteropType);
3572 if (NumDependences == nullptr) {
3573 NumDependences = ConstantInt::get(Int32, 0);
3574 PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
3575 DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
3576 }
3577 Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
3578 Value *Args[] = {
3579 Ident, ThreadId, InteropVar, InteropTypeVal,
3580 Device, NumDependences, DependenceAddress, HaveNowaitClauseVal};
3581
3582 Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_init);
3583
3584 return Builder.CreateCall(Fn, Args);
3585 }
3586
3587 CallInst *OpenMPIRBuilder::createOMPInteropDestroy(
3588 const LocationDescription &Loc, Value *InteropVar, Value *Device,
3589 Value *NumDependences, Value *DependenceAddress, bool HaveNowaitClause) {
3590 IRBuilder<>::InsertPointGuard IPG(Builder);
3591 Builder.restoreIP(Loc.IP);
3592
3593 uint32_t SrcLocStrSize;
3594 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3595 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3596 Value *ThreadId = getOrCreateThreadID(Ident);
3597 if (Device == nullptr)
3598 Device = ConstantInt::get(Int32, -1);
3599 if (NumDependences == nullptr) {
3600 NumDependences = ConstantInt::get(Int32, 0);
3601 PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
3602 DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
3603 }
3604 Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
3605 Value *Args[] = {
3606 Ident, ThreadId, InteropVar, Device,
3607 NumDependences, DependenceAddress, HaveNowaitClauseVal};
3608
3609 Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_destroy);
3610
3611 return Builder.CreateCall(Fn, Args);
3612 }
3613
3614 CallInst *OpenMPIRBuilder::createOMPInteropUse(const LocationDescription &Loc,
3615 Value *InteropVar, Value *Device,
3616 Value *NumDependences,
3617 Value *DependenceAddress,
3618 bool HaveNowaitClause) {
3619 IRBuilder<>::InsertPointGuard IPG(Builder);
3620 Builder.restoreIP(Loc.IP);
3621 uint32_t SrcLocStrSize;
3622 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3623 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3624 Value *ThreadId = getOrCreateThreadID(Ident);
3625 if (Device == nullptr)
3626 Device = ConstantInt::get(Int32, -1);
3627 if (NumDependences == nullptr) {
3628 NumDependences = ConstantInt::get(Int32, 0);
3629 PointerType *PointerTypeVar = Type::getInt8PtrTy(M.getContext());
3630 DependenceAddress = ConstantPointerNull::get(PointerTypeVar);
3631 }
3632 Value *HaveNowaitClauseVal = ConstantInt::get(Int32, HaveNowaitClause);
3633 Value *Args[] = {
3634 Ident, ThreadId, InteropVar, Device,
3635 NumDependences, DependenceAddress, HaveNowaitClauseVal};
3636
3637 Function *Fn = getOrCreateRuntimeFunctionPtr(OMPRTL___tgt_interop_use);
3638
3639 return Builder.CreateCall(Fn, Args);
3640 }
3641
3642 CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
3643 const LocationDescription &Loc, llvm::Value *Pointer,
3644 llvm::ConstantInt *Size, const llvm::Twine &Name) {
3645 IRBuilder<>::InsertPointGuard IPG(Builder);
3646 Builder.restoreIP(Loc.IP);
3647
3648 uint32_t SrcLocStrSize;
3649 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3650 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3651 Value *ThreadId = getOrCreateThreadID(Ident);
3652 Constant *ThreadPrivateCache =
3653 getOrCreateOMPInternalVariable(Int8PtrPtr, Name);
3654 llvm::Value *Args[] = {Ident, ThreadId, Pointer, Size, ThreadPrivateCache};
3655
3656 Function *Fn =
3657 getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_threadprivate_cached);
3658
3659 return Builder.CreateCall(Fn, Args);
3660 }
3661
3662 OpenMPIRBuilder::InsertPointTy
3663 OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
3664 bool RequiresFullRuntime) {
3665 if (!updateToLocation(Loc))
3666 return Loc.IP;
3667
3668 uint32_t SrcLocStrSize;
3669 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3670 Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3671 ConstantInt *IsSPMDVal = ConstantInt::getSigned(
3672 IntegerType::getInt8Ty(Int8->getContext()),
3673 IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
3674 ConstantInt *UseGenericStateMachine =
3675 ConstantInt::getBool(Int32->getContext(), !IsSPMD);
3676 ConstantInt *RequiresFullRuntimeVal =
3677 ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
3678
3679 Function *Fn = getOrCreateRuntimeFunctionPtr(
3680 omp::RuntimeFunction::OMPRTL___kmpc_target_init);
3681
3682 CallInst *ThreadKind = Builder.CreateCall(
3683 Fn, {Ident, IsSPMDVal, UseGenericStateMachine, RequiresFullRuntimeVal});
3684
3685 Value *ExecUserCode = Builder.CreateICmpEQ(
3686 ThreadKind, ConstantInt::get(ThreadKind->getType(), -1),
3687 "exec_user_code");
3688
3689 // ThreadKind = __kmpc_target_init(...)
3690 // if (ThreadKind == -1)
3691 // user_code
3692 // else
3693 // return;
3694
3695 auto *UI = Builder.CreateUnreachable();
3696 BasicBlock *CheckBB = UI->getParent();
3697 BasicBlock *UserCodeEntryBB = CheckBB->splitBasicBlock(UI, "user_code.entry");
3698
3699 BasicBlock *WorkerExitBB = BasicBlock::Create(
3700 CheckBB->getContext(), "worker.exit", CheckBB->getParent());
3701 Builder.SetInsertPoint(WorkerExitBB);
3702 Builder.CreateRetVoid();
3703
3704 auto *CheckBBTI = CheckBB->getTerminator();
3705 Builder.SetInsertPoint(CheckBBTI);
3706 Builder.CreateCondBr(ExecUserCode, UI->getParent(), WorkerExitBB);
3707
3708 CheckBBTI->eraseFromParent();
3709 UI->eraseFromParent();
3710
3711 // Continue in the "user_code" block, see diagram above and in
3712 // openmp/libomptarget/deviceRTLs/common/include/target.h .
3713 return InsertPointTy(UserCodeEntryBB, UserCodeEntryBB->getFirstInsertionPt());
3714 }
3715
3716 void OpenMPIRBuilder::createTargetDeinit(const LocationDescription &Loc,
3717 bool IsSPMD,
3718 bool RequiresFullRuntime) {
3719 if (!updateToLocation(Loc))
3720 return;
3721
3722 uint32_t SrcLocStrSize;
3723 Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
3724 Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
3725 ConstantInt *IsSPMDVal = ConstantInt::getSigned(
3726 IntegerType::getInt8Ty(Int8->getContext()),
3727 IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
3728 ConstantInt *RequiresFullRuntimeVal =
3729 ConstantInt::getBool(Int32->getContext(), RequiresFullRuntime);
3730
3731 Function *Fn = getOrCreateRuntimeFunctionPtr(
3732 omp::RuntimeFunction::OMPRTL___kmpc_target_deinit);
3733
3734 Builder.CreateCall(Fn, {Ident, IsSPMDVal, RequiresFullRuntimeVal});
3735 }
3736
3737 std::string OpenMPIRBuilder::getNameWithSeparators(ArrayRef<StringRef> Parts,
3738 StringRef FirstSeparator,
3739 StringRef Separator) {
3740 SmallString<128> Buffer;
3741 llvm::raw_svector_ostream OS(Buffer);
3742 StringRef Sep = FirstSeparator;
3743 for (StringRef Part : Parts) {
3744 OS << Sep << Part;
3745 Sep = Separator;
3746 }
3747 return OS.str().str();
3748 }
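// For example, getNameWithSeparators({"a", "b", "c"}, "$", ".") returns
// "$a.b.c": the first separator precedes the first part, and the regular
// separator joins the remaining parts.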
3749
3750 Constant *OpenMPIRBuilder::getOrCreateOMPInternalVariable(
3751 llvm::Type *Ty, const llvm::Twine &Name, unsigned AddressSpace) {
3752 // TODO: Replace the Twine arg with a StringRef to get rid of the conversion
3753 // logic. However, this is taken from the current implementation in Clang
3754 // as-is. Since this method is used in many places exclusively for OMP
3755 // internal use, we will keep it as-is temporarily until we move all users to
3756 // the builder and then, if possible, fix it everywhere in one go.
3757 SmallString<256> Buffer;
3758 llvm::raw_svector_ostream Out(Buffer);
3759 Out << Name;
3760 StringRef RuntimeName = Out.str();
3761 auto &Elem = *InternalVars.try_emplace(RuntimeName, nullptr).first;
3762 if (Elem.second) {
3763 assert(cast<PointerType>(Elem.second->getType())
3764 ->isOpaqueOrPointeeTypeMatches(Ty) &&
3765 "OMP internal variable has different type than requested");
3766 } else {
3767 // TODO: investigate the appropriate linkage type used for the global
3768 // variable for possibly changing that to internal or private, or maybe
3769 // create different versions of the function for different OMP internal
3770 // variables.
3771 Elem.second = new llvm::GlobalVariable(
3772 M, Ty, /*IsConstant*/ false, llvm::GlobalValue::CommonLinkage,
3773 llvm::Constant::getNullValue(Ty), Elem.first(),
3774 /*InsertBefore=*/nullptr, llvm::GlobalValue::NotThreadLocal,
3775 AddressSpace);
3776 }
3777
3778 return Elem.second;
3779 }
3780
3781 Value *OpenMPIRBuilder::getOMPCriticalRegionLock(StringRef CriticalName) {
3782 std::string Prefix = Twine("gomp_critical_user_", CriticalName).str();
3783 std::string Name = getNameWithSeparators({Prefix, "var"}, ".", ".");
3784 return getOrCreateOMPInternalVariable(KmpCriticalNameTy, Name);
3785 }
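// E.g. a critical region named "foo" is protected by the module-level global
// ".gomp_critical_user_foo.var" of type KmpCriticalNameTy, which common
// linkage merges across translation units.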
3786
3787 GlobalVariable *
3788 OpenMPIRBuilder::createOffloadMaptypes(SmallVectorImpl<uint64_t> &Mappings,
3789 std::string VarName) {
3790 llvm::Constant *MaptypesArrayInit =
3791 llvm::ConstantDataArray::get(M.getContext(), Mappings);
3792 auto *MaptypesArrayGlobal = new llvm::GlobalVariable(
3793 M, MaptypesArrayInit->getType(),
3794 /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MaptypesArrayInit,
3795 VarName);
3796 MaptypesArrayGlobal->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::Global);
3797 return MaptypesArrayGlobal;
3798 }
3799
3800 void OpenMPIRBuilder::createMapperAllocas(const LocationDescription &Loc,
3801 InsertPointTy AllocaIP,
3802 unsigned NumOperands,
3803 struct MapperAllocas &MapperAllocas) {
3804 if (!updateToLocation(Loc))
3805 return;
3806
3807 auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
3808 auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
3809 Builder.restoreIP(AllocaIP);
3810 AllocaInst *ArgsBase = Builder.CreateAlloca(ArrI8PtrTy);
3811 AllocaInst *Args = Builder.CreateAlloca(ArrI8PtrTy);
3812 AllocaInst *ArgSizes = Builder.CreateAlloca(ArrI64Ty);
3813 Builder.restoreIP(Loc.IP);
3814 MapperAllocas.ArgsBase = ArgsBase;
3815 MapperAllocas.Args = Args;
3816 MapperAllocas.ArgSizes = ArgSizes;
3817 }
3818
3819 void OpenMPIRBuilder::emitMapperCall(const LocationDescription &Loc,
3820 Function *MapperFunc, Value *SrcLocInfo,
3821 Value *MaptypesArg, Value *MapnamesArg,
3822 struct MapperAllocas &MapperAllocas,
3823 int64_t DeviceID, unsigned NumOperands) {
3824 if (!updateToLocation(Loc))
3825 return;
3826
3827 auto *ArrI8PtrTy = ArrayType::get(Int8Ptr, NumOperands);
3828 auto *ArrI64Ty = ArrayType::get(Int64, NumOperands);
3829 Value *ArgsBaseGEP =
3830 Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.ArgsBase,
3831 {Builder.getInt32(0), Builder.getInt32(0)});
3832 Value *ArgsGEP =
3833 Builder.CreateInBoundsGEP(ArrI8PtrTy, MapperAllocas.Args,
3834 {Builder.getInt32(0), Builder.getInt32(0)});
3835 Value *ArgSizesGEP =
3836 Builder.CreateInBoundsGEP(ArrI64Ty, MapperAllocas.ArgSizes,
3837 {Builder.getInt32(0), Builder.getInt32(0)});
3838 Value *NullPtr = Constant::getNullValue(Int8Ptr->getPointerTo());
3839 Builder.CreateCall(MapperFunc,
3840 {SrcLocInfo, Builder.getInt64(DeviceID),
3841 Builder.getInt32(NumOperands), ArgsBaseGEP, ArgsGEP,
3842 ArgSizesGEP, MaptypesArg, MapnamesArg, NullPtr});
3843 }
3844
3845 bool OpenMPIRBuilder::checkAndEmitFlushAfterAtomic(
3846 const LocationDescription &Loc, llvm::AtomicOrdering AO, AtomicKind AK) {
3847 assert(!(AO == AtomicOrdering::NotAtomic ||
3848 AO == llvm::AtomicOrdering::Unordered) &&
3849 "Unexpected Atomic Ordering.");
3850
3851 bool Flush = false;
3852 llvm::AtomicOrdering FlushAO = AtomicOrdering::Monotonic;
3853
3854 switch (AK) {
3855 case Read:
3856 if (AO == AtomicOrdering::Acquire || AO == AtomicOrdering::AcquireRelease ||
3857 AO == AtomicOrdering::SequentiallyConsistent) {
3858 FlushAO = AtomicOrdering::Acquire;
3859 Flush = true;
3860 }
3861 break;
3862 case Write:
3863 case Compare:
3864 case Update:
3865 if (AO == AtomicOrdering::Release || AO == AtomicOrdering::AcquireRelease ||
3866 AO == AtomicOrdering::SequentiallyConsistent) {
3867 FlushAO = AtomicOrdering::Release;
3868 Flush = true;
3869 }
3870 break;
3871 case Capture:
3872 switch (AO) {
3873 case AtomicOrdering::Acquire:
3874 FlushAO = AtomicOrdering::Acquire;
3875 Flush = true;
3876 break;
3877 case AtomicOrdering::Release:
3878 FlushAO = AtomicOrdering::Release;
3879 Flush = true;
3880 break;
3881 case AtomicOrdering::AcquireRelease:
3882 case AtomicOrdering::SequentiallyConsistent:
3883 FlushAO = AtomicOrdering::AcquireRelease;
3884 Flush = true;
3885 break;
3886 default:
3887 // do nothing - leave silently.
3888 break;
3889 }
3890 }
3891
3892 if (Flush) {
3893 // The flush runtime call does not take a memory_ordering argument yet. In
3894 // preparation for when it does, this already resolves which atomic ordering
3895 // to use, but for now issues an ordering-less flush call.
3896 // TODO: pass `FlushAO` after memory ordering support is added
3897 (void)FlushAO;
3898 emitFlush(Loc);
3899 }
3900
3901 // for AO == AtomicOrdering::Monotonic and all other case combinations
3902 // do nothing
3903 return Flush;
3904 }
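// Summary of the mapping implemented above (the FlushAO shown is what would be
// passed once the flush runtime call accepts a memory ordering):
//   read:                   acquire, acq_rel, seq_cst -> flush (acquire)
//   write, update, compare: release, acq_rel, seq_cst -> flush (release)
//   capture:                acquire -> flush (acquire)
//                           release -> flush (release)
//                           acq_rel, seq_cst -> flush (acq_rel)
// All other combinations emit no flush.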
3905
3906 OpenMPIRBuilder::InsertPointTy
3907 OpenMPIRBuilder::createAtomicRead(const LocationDescription &Loc,
3908 AtomicOpValue &X, AtomicOpValue &V,
3909 AtomicOrdering AO) {
3910 if (!updateToLocation(Loc))
3911 return Loc.IP;
3912
3913 Type *XTy = X.Var->getType();
3914 assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
3915 Type *XElemTy = X.ElemTy;
3916 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3917 XElemTy->isPointerTy()) &&
3918 "OMP atomic read expected a scalar type");
3919
3920 Value *XRead = nullptr;
3921
3922 if (XElemTy->isIntegerTy()) {
3923 LoadInst *XLD =
3924 Builder.CreateLoad(XElemTy, X.Var, X.IsVolatile, "omp.atomic.read");
3925 XLD->setAtomic(AO);
3926 XRead = cast<Value>(XLD);
3927 } else {
3928 // We need to bitcast and perform atomic op as integer
3929 unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
3930 IntegerType *IntCastTy =
3931 IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
3932 Value *XBCast = Builder.CreateBitCast(
3933 X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.src.int.cast");
3934 LoadInst *XLoad =
3935 Builder.CreateLoad(IntCastTy, XBCast, X.IsVolatile, "omp.atomic.load");
3936 XLoad->setAtomic(AO);
3937 if (XElemTy->isFloatingPointTy()) {
3938 XRead = Builder.CreateBitCast(XLoad, XElemTy, "atomic.flt.cast");
3939 } else {
3940 XRead = Builder.CreateIntToPtr(XLoad, XElemTy, "atomic.ptr.cast");
3941 }
3942 }
3943 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Read);
3944 Builder.CreateStore(XRead, V.Var, V.IsVolatile);
3945 return Builder.saveIP();
3946 }
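// Sketch of the sequence emitted for an integer element type (names are
// illustrative; the load's ordering comes from AO):
//   %omp.atomic.read = load atomic i32, i32* %x monotonic, align 4
//   store i32 %omp.atomic.read, i32* %v
// Floating-point and pointer element types are instead loaded through an
// integer pointer cast and converted back before the store to V.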
3947
3948 OpenMPIRBuilder::InsertPointTy
3949 OpenMPIRBuilder::createAtomicWrite(const LocationDescription &Loc,
3950 AtomicOpValue &X, Value *Expr,
3951 AtomicOrdering AO) {
3952 if (!updateToLocation(Loc))
3953 return Loc.IP;
3954
3955 Type *XTy = X.Var->getType();
3956 assert(XTy->isPointerTy() && "OMP Atomic expects a pointer to target memory");
3957 Type *XElemTy = X.ElemTy;
3958 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3959 XElemTy->isPointerTy()) &&
3960 "OMP atomic write expected a scalar type");
3961
3962 if (XElemTy->isIntegerTy()) {
3963 StoreInst *XSt = Builder.CreateStore(Expr, X.Var, X.IsVolatile);
3964 XSt->setAtomic(AO);
3965 } else {
3966 // We need to bitcast and perform atomic op as integers
3967 unsigned Addrspace = cast<PointerType>(XTy)->getAddressSpace();
3968 IntegerType *IntCastTy =
3969 IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
3970 Value *XBCast = Builder.CreateBitCast(
3971 X.Var, IntCastTy->getPointerTo(Addrspace), "atomic.dst.int.cast");
3972 Value *ExprCast =
3973 Builder.CreateBitCast(Expr, IntCastTy, "atomic.src.int.cast");
3974 StoreInst *XSt = Builder.CreateStore(ExprCast, XBCast, X.IsVolatile);
3975 XSt->setAtomic(AO);
3976 }
3977
3978 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Write);
3979 return Builder.saveIP();
3980 }
3981
3982 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicUpdate(
3983 const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
3984 Value *Expr, AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
3985 AtomicUpdateCallbackTy &UpdateOp, bool IsXBinopExpr) {
3986 assert(!isConflictIP(Loc.IP, AllocaIP) && "IPs must not be ambiguous");
3987 if (!updateToLocation(Loc))
3988 return Loc.IP;
3989
3990 LLVM_DEBUG({
3991 Type *XTy = X.Var->getType();
3992 assert(XTy->isPointerTy() &&
3993 "OMP Atomic expects a pointer to target memory");
3994 Type *XElemTy = X.ElemTy;
3995 assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
3996 XElemTy->isPointerTy()) &&
3997 "OMP atomic update expected a scalar type");
3998 assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
3999 (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
4000 "OpenMP atomic does not support LT or GT operations");
4001 });
4002
4003 emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, RMWOp, UpdateOp,
4004 X.IsVolatile, IsXBinopExpr);
4005 checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Update);
4006 return Builder.saveIP();
4007 }
4008
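// Recompute, as a plain (non-atomic) instruction, the value that an atomicrmw
// with operation RMWOp stored. atomicrmw returns the *old* value, so capturing
// the *updated* value requires rebuilding `old RMWOp expr` explicitly, e.g.
// for add: new = old + expr. Xchg has no corresponding binary instruction and
// is handled by the caller.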
Value *OpenMPIRBuilder::emitRMWOpAsInstruction(Value *Src1, Value *Src2,
                                               AtomicRMWInst::BinOp RMWOp) {
  switch (RMWOp) {
  case AtomicRMWInst::Add:
    return Builder.CreateAdd(Src1, Src2);
  case AtomicRMWInst::Sub:
    return Builder.CreateSub(Src1, Src2);
  case AtomicRMWInst::And:
    return Builder.CreateAnd(Src1, Src2);
  case AtomicRMWInst::Nand:
    // Nand is the bitwise complement of And (LangRef: *ptr = ~(*ptr & val)),
    // so emit a Not, not an arithmetic negation.
    return Builder.CreateNot(Builder.CreateAnd(Src1, Src2));
  case AtomicRMWInst::Or:
    return Builder.CreateOr(Src1, Src2);
  case AtomicRMWInst::Xor:
    return Builder.CreateXor(Src1, Src2);
  case AtomicRMWInst::Xchg:
  case AtomicRMWInst::FAdd:
  case AtomicRMWInst::FSub:
  case AtomicRMWInst::BAD_BINOP:
  case AtomicRMWInst::Max:
  case AtomicRMWInst::Min:
  case AtomicRMWInst::UMax:
  case AtomicRMWInst::UMin:
  case AtomicRMWInst::FMax:
  case AtomicRMWInst::FMin:
    llvm_unreachable("Unsupported atomic update operation");
  }
  llvm_unreachable("Unsupported atomic update operation");
}

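// Two lowering strategies are used below: if the update is a supported integer
// binop with 'x' on the left-hand side, a single atomicrmw suffices; otherwise
// an atomic load followed by a cmpxchg retry loop recomputes the update until
// no other thread has intervened. Sub only qualifies when 'x' is the LHS
// because atomicrmw sub computes `*ptr - val`, which cannot express
// `expr - x`.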
std::pair<Value *, Value *> OpenMPIRBuilder::emitAtomicUpdate(
    InsertPointTy AllocaIP, Value *X, Type *XElemTy, Value *Expr,
    AtomicOrdering AO, AtomicRMWInst::BinOp RMWOp,
    AtomicUpdateCallbackTy &UpdateOp, bool VolatileX, bool IsXBinopExpr) {
  // TODO: Handle the case where XElemTy is not byte-sized, not a power of
  // two, or a complex datatype.
  bool emitRMWOp = false;
  switch (RMWOp) {
  case AtomicRMWInst::Add:
  case AtomicRMWInst::And:
  case AtomicRMWInst::Nand:
  case AtomicRMWInst::Or:
  case AtomicRMWInst::Xor:
  case AtomicRMWInst::Xchg:
    emitRMWOp = XElemTy != nullptr;
    break;
  case AtomicRMWInst::Sub:
    // Sub is not commutative; see the note above the function.
    emitRMWOp = IsXBinopExpr && XElemTy != nullptr;
    break;
  default:
    emitRMWOp = false;
  }
  emitRMWOp &= XElemTy->isIntegerTy();

  std::pair<Value *, Value *> Res;
  if (emitRMWOp) {
    Res.first = Builder.CreateAtomicRMW(RMWOp, X, Expr, llvm::MaybeAlign(), AO);
    // Res.second, the updated value, is only needed when the result of the
    // update is captured (e.g. `v = ++x`); generate it anyway for consistency
    // with the else branch, any DCE pass will remove it when unused.
    // AtomicRMWInst::Xchg has no corresponding binary instruction; its new
    // value is simply the stored expression itself.
    if (RMWOp == AtomicRMWInst::Xchg)
      Res.second = Expr;
    else
      Res.second = emitRMWOpAsInstruction(Res.first, Expr, RMWOp);
  } else {
    unsigned Addrspace = cast<PointerType>(X->getType())->getAddressSpace();
    IntegerType *IntCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    Value *XBCast =
        Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
    LoadInst *OldVal =
        Builder.CreateLoad(IntCastTy, XBCast, X->getName() + ".atomic.load");
    OldVal->setAtomic(AO);
    // CurBB
    // |     /---\
    // ContBB    |
    // |     \---/
    // ExitBB
    BasicBlock *CurBB = Builder.GetInsertBlock();
    Instruction *CurBBTI = CurBB->getTerminator();
    // If the block has no terminator yet, create a placeholder unreachable so
    // that splitBasicBlock has an instruction to split at.
    CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
    BasicBlock *ExitBB =
        CurBB->splitBasicBlock(CurBBTI, X->getName() + ".atomic.exit");
    BasicBlock *ContBB = CurBB->splitBasicBlock(CurBB->getTerminator(),
                                                X->getName() + ".atomic.cont");
    ContBB->getTerminator()->eraseFromParent();
    Builder.restoreIP(AllocaIP);
    AllocaInst *NewAtomicAddr = Builder.CreateAlloca(XElemTy);
    NewAtomicAddr->setName(X->getName() + "x.new.val");
    Builder.SetInsertPoint(ContBB);
    llvm::PHINode *PHI = Builder.CreatePHI(OldVal->getType(), 2);
    PHI->addIncoming(OldVal, CurBB);
    IntegerType *NewAtomicCastTy =
        IntegerType::get(M.getContext(), XElemTy->getScalarSizeInBits());
    bool IsIntTy = XElemTy->isIntegerTy();
    Value *NewAtomicIntAddr =
        (IsIntTy)
            ? NewAtomicAddr
            : Builder.CreateBitCast(NewAtomicAddr,
                                    NewAtomicCastTy->getPointerTo(Addrspace));
    Value *OldExprVal = PHI;
    if (!IsIntTy) {
      if (XElemTy->isFloatingPointTy()) {
        OldExprVal = Builder.CreateBitCast(PHI, XElemTy,
                                           X->getName() + ".atomic.fltCast");
      } else {
        OldExprVal = Builder.CreateIntToPtr(PHI, XElemTy,
                                            X->getName() + ".atomic.ptrCast");
      }
    }

    Value *Upd = UpdateOp(OldExprVal, Builder);
    Builder.CreateStore(Upd, NewAtomicAddr);
    LoadInst *DesiredVal = Builder.CreateLoad(IntCastTy, NewAtomicIntAddr);
    Value *XAddr =
        (IsIntTy)
            ? X
            : Builder.CreateBitCast(X, IntCastTy->getPointerTo(Addrspace));
    AtomicOrdering Failure =
        llvm::AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = Builder.CreateAtomicCmpXchg(
        XAddr, PHI, DesiredVal, llvm::MaybeAlign(), AO, Failure);
    Result->setVolatile(VolatileX);
    Value *PreviousVal = Builder.CreateExtractValue(Result, /*Idxs=*/0);
    Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
    // On failure, retry with the value another thread just stored.
    PHI->addIncoming(PreviousVal, Builder.GetInsertBlock());
    Builder.CreateCondBr(SuccessFailureVal, ExitBB, ContBB);

    Res.first = OldExprVal;
    Res.second = Upd;

    // Set the insertion point in the exit block.
    if (isa<UnreachableInst>(ExitBB->getTerminator())) {
      // The terminator is the placeholder unreachable created above; drop it
      // and append to the exit block.
      CurBBTI->eraseFromParent();
      Builder.SetInsertPoint(ExitBB);
    } else {
      // The block already had a real terminator; insert before it.
      Builder.SetInsertPoint(ExitBB->getTerminator());
    }
  }

  return Res;
}

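// Emits `#pragma omp atomic capture`. Illustrative source forms (orientation
// examples, not taken from this file):
//   v = x++;   // postfix: 'v' receives the old value of 'x'
//   v = ++x;   // prefix: 'v' receives the updated value of 'x'
// IsPostfixUpdate selects which of the two values returned by
// emitAtomicUpdate is stored to 'v'.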
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCapture(
    const LocationDescription &Loc, InsertPointTy AllocaIP, AtomicOpValue &X,
    AtomicOpValue &V, Value *Expr, AtomicOrdering AO,
    AtomicRMWInst::BinOp RMWOp, AtomicUpdateCallbackTy &UpdateOp,
    bool UpdateExpr, bool IsPostfixUpdate, bool IsXBinopExpr) {
  if (!updateToLocation(Loc))
    return Loc.IP;

  LLVM_DEBUG({
    Type *XTy = X.Var->getType();
    assert(XTy->isPointerTy() &&
           "OMP Atomic expects a pointer to target memory");
    Type *XElemTy = X.ElemTy;
    assert((XElemTy->isFloatingPointTy() || XElemTy->isIntegerTy() ||
            XElemTy->isPointerTy()) &&
           "OMP atomic capture expected a scalar type");
    assert((RMWOp != AtomicRMWInst::Max) && (RMWOp != AtomicRMWInst::Min) &&
           (RMWOp != AtomicRMWInst::UMax) && (RMWOp != AtomicRMWInst::UMin) &&
           "OpenMP atomic does not support LT or GT operations");
  });

  // If 'x' is updated with an expression that is not based on 'x' itself
  // (UpdateExpr is false), 'x' is simply atomically exchanged with 'expr'.
  AtomicRMWInst::BinOp AtomicOp = (UpdateExpr ? RMWOp : AtomicRMWInst::Xchg);
  std::pair<Value *, Value *> Result =
      emitAtomicUpdate(AllocaIP, X.Var, X.ElemTy, Expr, AO, AtomicOp, UpdateOp,
                       X.IsVolatile, IsXBinopExpr);

  Value *CapturedVal = (IsPostfixUpdate ? Result.first : Result.second);
  Builder.CreateStore(CapturedVal, V.Var, V.IsVolatile);

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Capture);
  return Builder.saveIP();
}

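// Emits `#pragma omp atomic compare`. Two families of forms are handled:
// equality tests such as
//   x = x == e ? d : x;
// which lower to cmpxchg, and min/max forms such as
//   x = x > e ? e : x;
// which lower to atomicrmw min/max (see the operator discussion below).
// 'v' optionally captures a value and 'r' optionally receives the comparison
// result. (Orientation examples; not taken from this file.)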
OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
    const LocationDescription &Loc, AtomicOpValue &X, AtomicOpValue &V,
    AtomicOpValue &R, Value *E, Value *D, AtomicOrdering AO,
    omp::OMPAtomicCompareOp Op, bool IsXBinopExpr, bool IsPostfixUpdate,
    bool IsFailOnly) {

  if (!updateToLocation(Loc))
    return Loc.IP;

  assert(X.Var->getType()->isPointerTy() &&
         "OMP atomic expects a pointer to target memory");
  // compare capture
  if (V.Var) {
    assert(V.Var->getType()->isPointerTy() && "v.var must be of pointer type");
    assert(V.ElemTy == X.ElemTy && "x and v must be of same type");
  }

  bool IsInteger = E->getType()->isIntegerTy();

  if (Op == OMPAtomicCompareOp::EQ) {
    AtomicOrdering Failure = AtomicCmpXchgInst::getStrongestFailureOrdering(AO);
    AtomicCmpXchgInst *Result = nullptr;
    if (!IsInteger) {
      unsigned Addrspace =
          cast<PointerType>(X.Var->getType())->getAddressSpace();
      IntegerType *IntCastTy =
          IntegerType::get(M.getContext(), X.ElemTy->getScalarSizeInBits());
      Value *XBCast =
          Builder.CreateBitCast(X.Var, IntCastTy->getPointerTo(Addrspace));
      Value *EBCast = Builder.CreateBitCast(E, IntCastTy);
      Value *DBCast = Builder.CreateBitCast(D, IntCastTy);
      Result = Builder.CreateAtomicCmpXchg(XBCast, EBCast, DBCast, MaybeAlign(),
                                           AO, Failure);
    } else {
      Result =
          Builder.CreateAtomicCmpXchg(X.Var, E, D, MaybeAlign(), AO, Failure);
    }

    if (V.Var) {
      Value *OldValue = Builder.CreateExtractValue(Result, /*Idxs=*/0);
      if (!IsInteger)
        OldValue = Builder.CreateBitCast(OldValue, X.ElemTy);
      assert(OldValue->getType() == V.ElemTy &&
             "OldValue and V must be of same type");
      if (IsPostfixUpdate) {
        Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
      } else {
        Value *SuccessOrFail = Builder.CreateExtractValue(Result, /*Idxs=*/1);
        if (IsFailOnly) {
          // CurBB----
          // |        |
          // v        |
          // ContBB   |
          // |        |
          // v        |
          // ExitBB <-
          //
          // where ContBB only contains the store of the old value to 'v'.
          BasicBlock *CurBB = Builder.GetInsertBlock();
          Instruction *CurBBTI = CurBB->getTerminator();
          CurBBTI = CurBBTI ? CurBBTI : Builder.CreateUnreachable();
          BasicBlock *ExitBB = CurBB->splitBasicBlock(
              CurBBTI, X.Var->getName() + ".atomic.exit");
          BasicBlock *ContBB = CurBB->splitBasicBlock(
              CurBB->getTerminator(), X.Var->getName() + ".atomic.cont");
          ContBB->getTerminator()->eraseFromParent();
          CurBB->getTerminator()->eraseFromParent();

          Builder.CreateCondBr(SuccessOrFail, ExitBB, ContBB);

          Builder.SetInsertPoint(ContBB);
          Builder.CreateStore(OldValue, V.Var, V.IsVolatile);
          Builder.CreateBr(ExitBB);

          if (isa<UnreachableInst>(ExitBB->getTerminator())) {
            // Drop the placeholder unreachable created above and append to
            // the exit block.
            CurBBTI->eraseFromParent();
            Builder.SetInsertPoint(ExitBB);
          } else {
            // The block already had a real terminator; insert before it.
            Builder.SetInsertPoint(ExitBB->getTerminator());
          }
        } else {
          // On success the cmpxchg stored 'D'; otherwise 'x' kept its old
          // value.
          Value *CapturedValue =
              Builder.CreateSelect(SuccessOrFail, D, OldValue);
          Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
        }
      }
    }
    // Store the comparison result if requested.
    if (R.Var) {
      assert(R.Var->getType()->isPointerTy() &&
             "r.var must be of pointer type");
      assert(R.ElemTy->isIntegerTy() && "r must be of integral type");

      Value *SuccessFailureVal = Builder.CreateExtractValue(Result, /*Idxs=*/1);
      Value *ResultCast = R.IsSigned
                              ? Builder.CreateSExt(SuccessFailureVal, R.ElemTy)
                              : Builder.CreateZExt(SuccessFailureVal, R.ElemTy);
      Builder.CreateStore(ResultCast, R.Var, R.IsVolatile);
    }
  } else {
    assert((Op == OMPAtomicCompareOp::MAX || Op == OMPAtomicCompareOp::MIN) &&
           "Op should be either max or min at this point");
    assert(!IsFailOnly && "IsFailOnly is only valid when the comparison is ==");

    // The OpenMP min/max forms and LLVM's atomicrmw operations are reversed,
    // so the ordop has to be flipped. Take max as an example. The OpenMP form
    //   x = x > expr ? expr : x;
    // stores the *smaller* value, whereas LLVM's max operation
    //   *ptr = *ptr > val ? *ptr : val;
    // keeps the *larger* one. Rewriting the OpenMP form as
    //   x = x <= expr ? x : expr;
    // shows that it is really a min, so max and min must be swapped (and
    // likewise for the unsigned and floating-point variants below).
    AtomicRMWInst::BinOp NewOp;
    if (IsXBinopExpr) {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Min
                                                : AtomicRMWInst::Max;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMin
                                                : AtomicRMWInst::UMax;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMin
                                              : AtomicRMWInst::FMax;
      }
    } else {
      if (IsInteger) {
        if (X.IsSigned)
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::Max
                                                : AtomicRMWInst::Min;
        else
          NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::UMax
                                                : AtomicRMWInst::UMin;
      } else {
        NewOp = Op == OMPAtomicCompareOp::MAX ? AtomicRMWInst::FMax
                                              : AtomicRMWInst::FMin;
      }
    }

    AtomicRMWInst *OldValue =
        Builder.CreateAtomicRMW(NewOp, X.Var, E, MaybeAlign(), AO);
    if (V.Var) {
      Value *CapturedValue = nullptr;
      if (IsPostfixUpdate) {
        CapturedValue = OldValue;
      } else {
        // Recompute the new value of 'x' non-atomically from the old value
        // the atomicrmw returned, using the matching comparison.
        CmpInst::Predicate Pred;
        switch (NewOp) {
        case AtomicRMWInst::Max:
          Pred = CmpInst::ICMP_SGT;
          break;
        case AtomicRMWInst::UMax:
          Pred = CmpInst::ICMP_UGT;
          break;
        case AtomicRMWInst::FMax:
          Pred = CmpInst::FCMP_OGT;
          break;
        case AtomicRMWInst::Min:
          Pred = CmpInst::ICMP_SLT;
          break;
        case AtomicRMWInst::UMin:
          Pred = CmpInst::ICMP_ULT;
          break;
        case AtomicRMWInst::FMin:
          Pred = CmpInst::FCMP_OLT;
          break;
        default:
          llvm_unreachable("unexpected comparison op");
        }
        Value *NonAtomicCmp = Builder.CreateCmp(Pred, OldValue, E);
        // E.g. for NewOp == Max, this selects OldValue when OldValue > E,
        // i.e. the maximum that the atomicrmw stored.
        CapturedValue = Builder.CreateSelect(NonAtomicCmp, OldValue, E);
      }
      Builder.CreateStore(CapturedValue, V.Var, V.IsVolatile);
    }
  }

  checkAndEmitFlushAfterAtomic(Loc, AO, AtomicKind::Compare);

  return Builder.saveIP();
}

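// Build a private constant global holding the array of map-name strings,
// which the offloading runtime can use when reporting map clauses. Each
// element of Names is expected to be a pointer to a string constant created
// elsewhere (a reading of the code below, not a spec quote).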
GlobalVariable *
OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                       std::string VarName) {
  llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
      llvm::ArrayType::get(
          llvm::Type::getInt8Ty(M.getContext())->getPointerTo(), Names.size()),
      Names);
  auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
      M, MapNamesArrayInit->getType(),
      /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, MapNamesArrayInit,
      VarName);
  return MapNamesArrayGlobal;
}

// Create all simple and struct types exposed by the runtime and remember
// their llvm::PointerTypes for easy access later.
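// For orientation, an entry such as
//   OMP_STRUCT_TYPE(Ident, "struct.ident_t", Int32, Int32, Int32, Int32,
//                   Int8Ptr)
// first looks up "struct.ident_t" in the context, creates it if absent, and
// defines both Ident and IdentPtr. (Illustrative expansion only; the actual
// entries live in OMPKinds.def.)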
void OpenMPIRBuilder::initializeTypes(Module &M) {
  LLVMContext &Ctx = M.getContext();
  StructType *T;
#define OMP_TYPE(VarName, InitValue) VarName = InitValue;
#define OMP_ARRAY_TYPE(VarName, ElemTy, ArraySize)                            \
  VarName##Ty = ArrayType::get(ElemTy, ArraySize);                            \
  VarName##PtrTy = PointerType::getUnqual(VarName##Ty);
#define OMP_FUNCTION_TYPE(VarName, IsVarArg, ReturnType, ...)                 \
  VarName = FunctionType::get(ReturnType, {__VA_ARGS__}, IsVarArg);           \
  VarName##Ptr = PointerType::getUnqual(VarName);
#define OMP_STRUCT_TYPE(VarName, StructName, ...)                             \
  T = StructType::getTypeByName(Ctx, StructName);                             \
  if (!T)                                                                     \
    T = StructType::create(Ctx, {__VA_ARGS__}, StructName);                   \
  VarName = T;                                                                \
  VarName##Ptr = PointerType::getUnqual(T);
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}

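// Collect all blocks of the outlined region reachable from EntryBB. ExitBB is
// pre-inserted into the set so the traversal stops there and does not wander
// past the region's exit; as a consequence ExitBB ends up in BlockSet but not
// in BlockVector.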
void OpenMPIRBuilder::OutlineInfo::collectBlocks(
    SmallPtrSetImpl<BasicBlock *> &BlockSet,
    SmallVectorImpl<BasicBlock *> &BlockVector) {
  SmallVector<BasicBlock *, 32> Worklist;
  BlockSet.insert(EntryBB);
  BlockSet.insert(ExitBB);

  Worklist.push_back(EntryBB);
  while (!Worklist.empty()) {
    BasicBlock *BB = Worklist.pop_back_val();
    BlockVector.push_back(BB);
    for (BasicBlock *SuccBB : successors(BB))
      if (BlockSet.insert(SuccBB).second)
        Worklist.push_back(SuccBB);
  }
}

void CanonicalLoopInfo::collectControlBlocks(
    SmallVectorImpl<BasicBlock *> &BBs) {
  // We only count as control blocks those BBs for which we do not need to
  // traverse the CFG, i.e. not the loop body, which can contain arbitrary
  // control flow. For consistency, this also means we do not add the Body
  // block, which is just the entry to the body code.
  BBs.reserve(BBs.size() + 6);
  BBs.append({getPreheader(), Header, Cond, Latch, Exit, getAfter()});
}

BasicBlock *CanonicalLoopInfo::getPreheader() const {
  assert(isValid() && "Requires a valid canonical loop");
  for (BasicBlock *Pred : predecessors(Header)) {
    if (Pred != Latch)
      return Pred;
  }
  llvm_unreachable("Missing preheader");
}

void CanonicalLoopInfo::setTripCount(Value *TripCount) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *CmpI = &getCond()->front();
  assert(isa<CmpInst>(CmpI) && "First inst must compare IV with TripCount");
  CmpI->setOperand(1, TripCount);

#ifndef NDEBUG
  assertOK();
#endif
}

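// A typical use of mapIndVar (hypothetical sketch, not code from this file):
// rebase the induction variable onto a chunk's lower bound when a workshare
// loop is partitioned, e.g.
//   CLI->mapIndVar([&](Instruction *OldIV) -> Value * {
//     Builder.SetInsertPoint(CLI->getBody(),
//                            CLI->getBody()->getFirstInsertionPt());
//     return Builder.CreateAdd(OldIV, LowerBound, "iv.rebased");
//   });
// where LowerBound is assumed to be a Value holding the chunk's start.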
void CanonicalLoopInfo::mapIndVar(
    llvm::function_ref<Value *(Instruction *)> Updater) {
  assert(isValid() && "Requires a valid canonical loop");

  Instruction *OldIV = getIndVar();

  // Record the uses to be replaced before running the updater, so that uses
  // the updater introduces are not rewritten. Uses in the condition and latch
  // blocks, which the CanonicalLoopInfo itself needs to keep track of the
  // number of iterations, are excluded as well.
  SmallVector<Use *> ReplacableUses;
  for (Use &U : OldIV->uses()) {
    auto *User = dyn_cast<Instruction>(U.getUser());
    if (!User)
      continue;
    if (User->getParent() == getCond())
      continue;
    if (User->getParent() == getLatch())
      continue;
    ReplacableUses.push_back(&U);
  }

  // Run the updater, which may introduce new uses of the induction variable.
  Value *NewIV = Updater(OldIV);

  // Replace the old uses with the value returned by the updater.
  for (Use *U : ReplacableUses)
    U->set(NewIV);

#ifndef NDEBUG
  assertOK();
#endif
}

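// For orientation, the canonical loop shape checked below is:
//   Preheader -> Header -> Cond -> Body -> ... -> Latch -> Header (backedge)
//                            \-> Exit -> After
// (This sketch merely restates the asserts that follow.)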
void CanonicalLoopInfo::assertOK() const {
#ifndef NDEBUG
  // No constraints if this object currently does not describe a loop.
  if (!isValid())
    return;

  BasicBlock *Preheader = getPreheader();
  BasicBlock *Body = getBody();
  BasicBlock *After = getAfter();

  // Verify standard control-flow we use for OpenMP loops.
  assert(Preheader);
  assert(isa<BranchInst>(Preheader->getTerminator()) &&
         "Preheader must terminate with unconditional branch");
  assert(Preheader->getSingleSuccessor() == Header &&
         "Preheader must jump to header");

  assert(Header);
  assert(isa<BranchInst>(Header->getTerminator()) &&
         "Header must terminate with unconditional branch");
  assert(Header->getSingleSuccessor() == Cond &&
         "Header must jump to exiting block");

  assert(Cond);
  assert(Cond->getSinglePredecessor() == Header &&
         "Exiting block only reachable from header");

  assert(isa<BranchInst>(Cond->getTerminator()) &&
         "Exiting block must terminate with conditional branch");
  assert(size(successors(Cond)) == 2 &&
         "Exiting block must have two successors");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(0) == Body &&
         "Exiting block's first successor must jump to the body");
  assert(cast<BranchInst>(Cond->getTerminator())->getSuccessor(1) == Exit &&
         "Exiting block's second successor must exit the loop");

  assert(Body);
  assert(Body->getSinglePredecessor() == Cond &&
         "Body only reachable from exiting block");
  assert(!isa<PHINode>(Body->front()));

  assert(Latch);
  assert(isa<BranchInst>(Latch->getTerminator()) &&
         "Latch must terminate with unconditional branch");
  assert(Latch->getSingleSuccessor() == Header && "Latch must jump to header");
  // TODO: To support simple redirecting of the end of the body code, which
  // may consist of multiple basic blocks, introduce another auxiliary basic
  // block like preheader and after.
  assert(Latch->getSinglePredecessor() != nullptr);
  assert(!isa<PHINode>(Latch->front()));

  assert(Exit);
  assert(isa<BranchInst>(Exit->getTerminator()) &&
         "Exit block must terminate with unconditional branch");
  assert(Exit->getSingleSuccessor() == After &&
         "Exit block must jump to after block");

  assert(After);
  assert(After->getSinglePredecessor() == Exit &&
         "After block only reachable from exit block");
  assert(After->empty() || !isa<PHINode>(After->front()));

  Instruction *IndVar = getIndVar();
  assert(IndVar && "Canonical induction variable not found?");
  assert(isa<IntegerType>(IndVar->getType()) &&
         "Induction variable must be an integer");
  assert(cast<PHINode>(IndVar)->getParent() == Header &&
         "Induction variable must be a PHI in the loop header");
  assert(cast<PHINode>(IndVar)->getIncomingBlock(0) == Preheader);
  assert(
      cast<ConstantInt>(cast<PHINode>(IndVar)->getIncomingValue(0))->isZero());
  assert(cast<PHINode>(IndVar)->getIncomingBlock(1) == Latch);

  auto *NextIndVar = cast<PHINode>(IndVar)->getIncomingValue(1);
  assert(cast<Instruction>(NextIndVar)->getParent() == Latch);
  assert(cast<BinaryOperator>(NextIndVar)->getOpcode() == BinaryOperator::Add);
  assert(cast<BinaryOperator>(NextIndVar)->getOperand(0) == IndVar);
  assert(cast<ConstantInt>(cast<BinaryOperator>(NextIndVar)->getOperand(1))
             ->isOne());

  Value *TripCount = getTripCount();
  assert(TripCount && "Loop trip count not found?");
  assert(IndVar->getType() == TripCount->getType() &&
         "Trip count and induction variable must have the same type");

  auto *CmpI = cast<CmpInst>(&Cond->front());
  assert(CmpI->getPredicate() == CmpInst::ICMP_ULT &&
         "Exit condition must be an unsigned less-than comparison");
  assert(CmpI->getOperand(0) == IndVar &&
         "Exit condition must compare the induction variable");
  assert(CmpI->getOperand(1) == TripCount &&
         "Exit condition must compare with the trip count");
#endif
}

void CanonicalLoopInfo::invalidate() {
  Header = nullptr;
  Cond = nullptr;
  Latch = nullptr;
  Exit = nullptr;
}
