1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/IPO/OpenMPOpt.h"
18 
19 #include "llvm/ADT/EnumeratedArray.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/Analysis/CallGraph.h"
23 #include "llvm/Analysis/CallGraphSCCPass.h"
24 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Frontend/OpenMP/OMPConstants.h"
27 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
28 #include "llvm/IR/IntrinsicInst.h"
29 #include "llvm/IR/IntrinsicsAMDGPU.h"
30 #include "llvm/IR/IntrinsicsNVPTX.h"
31 #include "llvm/IR/PatternMatch.h"
32 #include "llvm/InitializePasses.h"
33 #include "llvm/Support/CommandLine.h"
34 #include "llvm/Transforms/IPO.h"
35 #include "llvm/Transforms/IPO/Attributor.h"
36 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
37 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
38 #include "llvm/Transforms/Utils/CodeExtractor.h"
39 
40 using namespace llvm::PatternMatch;
41 using namespace llvm;
42 using namespace omp;
43 
44 #define DEBUG_TYPE "openmp-opt"
45 
46 static cl::opt<bool> DisableOpenMPOptimizations(
47     "openmp-opt-disable", cl::ZeroOrMore,
48     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
49     cl::init(false));
50 
51 static cl::opt<bool> EnableParallelRegionMerging(
52     "openmp-opt-enable-merging", cl::ZeroOrMore,
53     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
54     cl::init(false));
55 
56 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
57                                     cl::Hidden);
58 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
59                                         cl::init(false), cl::Hidden);
60 
61 static cl::opt<bool> HideMemoryTransferLatency(
62     "openmp-hide-memory-transfer-latency",
63     cl::desc("[WIP] Tries to hide the latency of host to device memory"
64              " transfers"),
65     cl::Hidden, cl::init(false));
66 
67 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
68           "Number of OpenMP runtime calls deduplicated");
69 STATISTIC(NumOpenMPParallelRegionsDeleted,
70           "Number of OpenMP parallel regions deleted");
71 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
72           "Number of OpenMP runtime functions identified");
73 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
74           "Number of OpenMP runtime function uses identified");
75 STATISTIC(NumOpenMPTargetRegionKernels,
76           "Number of OpenMP target region entry points (=kernels) identified");
77 STATISTIC(
78     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
79     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
80 STATISTIC(NumOpenMPParallelRegionsMerged,
81           "Number of OpenMP parallel regions merged");
82 STATISTIC(NumBytesMovedToSharedMemory,
83           "Amount of memory pushed to shared memory");
84 
85 #if !defined(NDEBUG)
86 static constexpr auto TAG = "[" DEBUG_TYPE "]";
87 #endif
88 
89 namespace {
90 
91 enum class AddressSpace : unsigned {
92   Generic = 0,
93   Global = 1,
94   Shared = 3,
95   Constant = 4,
96   Local = 5,
97 };
98 
99 struct AAHeapToShared;
100 
101 struct AAICVTracker;
102 
103 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
104 /// Attributor runs.
105 struct OMPInformationCache : public InformationCache {
106   OMPInformationCache(Module &M, AnalysisGetter &AG,
107                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
108                       SmallPtrSetImpl<Kernel> &Kernels)
109       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
110         Kernels(Kernels) {
111 
112     OMPBuilder.initialize();
113     initializeRuntimeFunctions();
114     initializeInternalControlVars();
115   }
116 
117   /// Generic information that describes an internal control variable.
118   struct InternalControlVarInfo {
119     /// The kind, as described by InternalControlVar enum.
120     InternalControlVar Kind;
121 
122     /// The name of the ICV.
123     StringRef Name;
124 
125     /// Environment variable associated with this ICV.
126     StringRef EnvVarName;
127 
128     /// Initial value kind.
129     ICVInitValue InitKind;
130 
131     /// Initial value.
132     ConstantInt *InitValue;
133 
134     /// Setter RTL function associated with this ICV.
135     RuntimeFunction Setter;
136 
137     /// Getter RTL function associated with this ICV.
138     RuntimeFunction Getter;
139 
140     /// RTL Function corresponding to the override clause of this ICV
141     RuntimeFunction Clause;
142   };
143 
144   /// Generic information that describes a runtime function
145   struct RuntimeFunctionInfo {
146 
147     /// The kind, as described by the RuntimeFunction enum.
148     RuntimeFunction Kind;
149 
150     /// The name of the function.
151     StringRef Name;
152 
153     /// Flag to indicate a variadic function.
154     bool IsVarArg;
155 
156     /// The return type of the function.
157     Type *ReturnType;
158 
159     /// The argument types of the function.
160     SmallVector<Type *, 8> ArgumentTypes;
161 
162     /// The declaration if available.
163     Function *Declaration = nullptr;
164 
165     /// Uses of this runtime function per function containing the use.
166     using UseVector = SmallVector<Use *, 16>;
167 
168     /// Clear UsesMap for runtime function.
169     void clearUsesMap() { UsesMap.clear(); }
170 
171     /// Boolean conversion that is true if the runtime function was found.
172     operator bool() const { return Declaration; }
173 
174     /// Return the vector of uses in function \p F.
175     UseVector &getOrCreateUseVector(Function *F) {
176       std::shared_ptr<UseVector> &UV = UsesMap[F];
177       if (!UV)
178         UV = std::make_shared<UseVector>();
179       return *UV;
180     }
181 
182     /// Return the vector of uses in function \p F or `nullptr` if there are
183     /// none.
184     const UseVector *getUseVector(Function &F) const {
185       auto I = UsesMap.find(&F);
186       if (I != UsesMap.end())
187         return I->second.get();
188       return nullptr;
189     }
190 
191     /// Return how many functions contain uses of this runtime function.
192     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
193 
194     /// Return the number of arguments (or the minimal number for variadic
195     /// functions).
196     size_t getNumArgs() const { return ArgumentTypes.size(); }
197 
198     /// Run the callback \p CB on each use and forget the use if the result is
199     /// true. The callback will be fed the function in which the use was
200     /// encountered as second argument.
201     void foreachUse(SmallVectorImpl<Function *> &SCC,
202                     function_ref<bool(Use &, Function &)> CB) {
203       for (Function *F : SCC)
204         foreachUse(CB, F);
205     }
206 
207     /// Run the callback \p CB on each use within the function \p F and forget
208     /// the use if the result is true.
209     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
210       SmallVector<unsigned, 8> ToBeDeleted;
211       ToBeDeleted.clear();
212 
213       unsigned Idx = 0;
214       UseVector &UV = getOrCreateUseVector(F);
215 
216       for (Use *U : UV) {
217         if (CB(*U, *F))
218           ToBeDeleted.push_back(Idx);
219         ++Idx;
220       }
221 
222       // Remove the to-be-deleted indices in reverse order as prior
223       // modifications will not modify the smaller indices.
224       while (!ToBeDeleted.empty()) {
225         unsigned Idx = ToBeDeleted.pop_back_val();
226         UV[Idx] = UV.back();
227         UV.pop_back();
228       }
229     }
230 
231   private:
232     /// Map from functions to all uses of this runtime function contained in
233     /// them.
234     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
235   };
236 
237   /// An OpenMP-IR-Builder instance
238   OpenMPIRBuilder OMPBuilder;
239 
240   /// Map from runtime function kind to the runtime function description.
241   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
242                   RuntimeFunction::OMPRTL___last>
243       RFIs;
244 
245   /// Map from ICV kind to the ICV description.
246   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
247                   InternalControlVar::ICV___last>
248       ICVs;
249 
250   /// Helper to initialize all internal control variable information for those
251   /// defined in OMPKinds.def.
252   void initializeInternalControlVars() {
253 #define ICV_RT_SET(_Name, RTL)                                                 \
254   {                                                                            \
255     auto &ICV = ICVs[_Name];                                                   \
256     ICV.Setter = RTL;                                                          \
257   }
258 #define ICV_RT_GET(Name, RTL)                                                  \
259   {                                                                            \
260     auto &ICV = ICVs[Name];                                                    \
261     ICV.Getter = RTL;                                                          \
262   }
263 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
264   {                                                                            \
265     auto &ICV = ICVs[Enum];                                                    \
266     ICV.Name = _Name;                                                          \
267     ICV.Kind = Enum;                                                           \
268     ICV.InitKind = Init;                                                       \
269     ICV.EnvVarName = _EnvVarName;                                              \
270     switch (ICV.InitKind) {                                                    \
271     case ICV_IMPLEMENTATION_DEFINED:                                           \
272       ICV.InitValue = nullptr;                                                 \
273       break;                                                                   \
274     case ICV_ZERO:                                                             \
275       ICV.InitValue = ConstantInt::get(                                        \
276           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
277       break;                                                                   \
278     case ICV_FALSE:                                                            \
279       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
280       break;                                                                   \
281     case ICV_LAST:                                                             \
282       break;                                                                   \
283     }                                                                          \
284   }
285 #include "llvm/Frontend/OpenMP/OMPKinds.def"
286   }
287 
288   /// Returns true if the function declaration \p F matches the runtime
289   /// function types, that is, return type \p RTFRetType, and argument types
290   /// \p RTFArgTypes.
291   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
292                                   SmallVector<Type *, 8> &RTFArgTypes) {
293     // TODO: We should output information to the user (under debug output
294     //       and via remarks).
295 
296     if (!F)
297       return false;
298     if (F->getReturnType() != RTFRetType)
299       return false;
300     if (F->arg_size() != RTFArgTypes.size())
301       return false;
302 
303     auto RTFTyIt = RTFArgTypes.begin();
304     for (Argument &Arg : F->args()) {
305       if (Arg.getType() != *RTFTyIt)
306         return false;
307 
308       ++RTFTyIt;
309     }
310 
311     return true;
312   }
313 
314   // Helper to collect all uses of the declaration in the UsesMap.
315   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
316     unsigned NumUses = 0;
317     if (!RFI.Declaration)
318       return NumUses;
319     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
320 
321     if (CollectStats) {
322       NumOpenMPRuntimeFunctionsIdentified += 1;
323       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
324     }
325 
326     // TODO: We directly convert uses into proper calls and unknown uses.
327     for (Use &U : RFI.Declaration->uses()) {
328       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
329         if (ModuleSlice.count(UserI->getFunction())) {
330           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
331           ++NumUses;
332         }
333       } else {
334         RFI.getOrCreateUseVector(nullptr).push_back(&U);
335         ++NumUses;
336       }
337     }
338     return NumUses;
339   }
340 
341   // Helper function to recollect uses of a runtime function.
342   void recollectUsesForFunction(RuntimeFunction RTF) {
343     auto &RFI = RFIs[RTF];
344     RFI.clearUsesMap();
345     collectUses(RFI, /*CollectStats*/ false);
346   }
347 
348   // Helper function to recollect uses of all runtime functions.
349   void recollectUses() {
350     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
351       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
352   }
353 
354   /// Helper to initialize all runtime function information for those defined
355   /// in OpenMPKinds.def.
356   void initializeRuntimeFunctions() {
357     Module &M = *((*ModuleSlice.begin())->getParent());
358 
359     // Helper macros for handling __VA_ARGS__ in OMP_RTL
360 #define OMP_TYPE(VarName, ...)                                                 \
361   Type *VarName = OMPBuilder.VarName;                                          \
362   (void)VarName;
363 
364 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
365   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
366   (void)VarName##Ty;                                                           \
367   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
368   (void)VarName##PtrTy;
369 
370 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
371   FunctionType *VarName = OMPBuilder.VarName;                                  \
372   (void)VarName;                                                               \
373   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
374   (void)VarName##Ptr;
375 
376 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
377   StructType *VarName = OMPBuilder.VarName;                                    \
378   (void)VarName;                                                               \
379   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
380   (void)VarName##Ptr;
381 
382 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
383   {                                                                            \
384     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
385     Function *F = M.getFunction(_Name);                                        \
386     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
387       auto &RFI = RFIs[_Enum];                                                 \
388       RFI.Kind = _Enum;                                                        \
389       RFI.Name = _Name;                                                        \
390       RFI.IsVarArg = _IsVarArg;                                                \
391       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
392       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
393       RFI.Declaration = F;                                                     \
394       unsigned NumUses = collectUses(RFI);                                     \
395       (void)NumUses;                                                           \
396       LLVM_DEBUG({                                                             \
397         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
398                << " found\n";                                                  \
399         if (RFI.Declaration)                                                   \
400           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
401                  << RFI.getNumFunctionsWithUses()                              \
402                  << " different functions.\n";                                 \
403       });                                                                      \
404     }                                                                          \
405   }
406 #include "llvm/Frontend/OpenMP/OMPKinds.def"
407 
408     // TODO: We should attach the attributes defined in OMPKinds.def.
409   }
410 
411   /// Collection of known kernels (\see Kernel) in the module.
412   SmallPtrSetImpl<Kernel> &Kernels;
413 };
414 
415 /// Used to map the values physically (in the IR) stored in an offload
416 /// array, to a vector in memory.
417 struct OffloadArray {
418   /// Physical array (in the IR).
419   AllocaInst *Array = nullptr;
420   /// Mapped values.
421   SmallVector<Value *, 8> StoredValues;
422   /// Last stores made in the offload array.
423   SmallVector<StoreInst *, 8> LastAccesses;
424 
425   OffloadArray() = default;
426 
427   /// Initializes the OffloadArray with the values stored in \p Array before
428   /// instruction \p Before is reached. Returns false if the initialization
429   /// fails.
430   /// This MUST be used immediately after the construction of the object.
431   bool initialize(AllocaInst &Array, Instruction &Before) {
432     if (!Array.getAllocatedType()->isArrayTy())
433       return false;
434 
435     if (!getValues(Array, Before))
436       return false;
437 
438     this->Array = &Array;
439     return true;
440   }
441 
442   static const unsigned DeviceIDArgNum = 1;
443   static const unsigned BasePtrsArgNum = 3;
444   static const unsigned PtrsArgNum = 4;
445   static const unsigned SizesArgNum = 5;
446 
447 private:
448   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
449   /// \p Array, leaving StoredValues with the values stored before the
450   /// instruction \p Before is reached.
451   bool getValues(AllocaInst &Array, Instruction &Before) {
452     // Initialize container.
453     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
454     StoredValues.assign(NumValues, nullptr);
455     LastAccesses.assign(NumValues, nullptr);
456 
457     // TODO: This assumes the instruction \p Before is in the same
458     //  BasicBlock as Array. Make it general, for any control flow graph.
459     BasicBlock *BB = Array.getParent();
460     if (BB != Before.getParent())
461       return false;
462 
463     const DataLayout &DL = Array.getModule()->getDataLayout();
464     const unsigned int PointerSize = DL.getPointerSize();
465 
466     for (Instruction &I : *BB) {
467       if (&I == &Before)
468         break;
469 
470       if (!isa<StoreInst>(&I))
471         continue;
472 
473       auto *S = cast<StoreInst>(&I);
474       int64_t Offset = -1;
475       auto *Dst =
476           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
477       if (Dst == &Array) {
478         int64_t Idx = Offset / PointerSize;
479         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
480         LastAccesses[Idx] = S;
481       }
482     }
483 
484     return isFilled();
485   }
486 
487   /// Returns true if all values in StoredValues and
488   /// LastAccesses are not nullptrs.
489   bool isFilled() {
490     const unsigned NumValues = StoredValues.size();
491     for (unsigned I = 0; I < NumValues; ++I) {
492       if (!StoredValues[I] || !LastAccesses[I])
493         return false;
494     }
495 
496     return true;
497   }
498 };
499 
500 struct OpenMPOpt {
501 
502   using OptimizationRemarkGetter =
503       function_ref<OptimizationRemarkEmitter &(Function *)>;
504 
505   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
506             OptimizationRemarkGetter OREGetter,
507             OMPInformationCache &OMPInfoCache, Attributor &A)
508       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
509         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
510 
511   /// Check if any remarks are enabled for openmp-opt
512   bool remarksEnabled() {
513     auto &Ctx = M.getContext();
514     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
515   }
516 
517   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
518   bool run(bool IsModulePass) {
519     if (SCC.empty())
520       return false;
521 
522     bool Changed = false;
523 
524     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
525                       << " functions in a slice with "
526                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
527 
528     if (IsModulePass) {
529       Changed |= runAttributor();
530 
531       // Recollect uses, in case Attributor deleted any.
532       OMPInfoCache.recollectUses();
533 
534       if (remarksEnabled())
535         analysisGlobalization();
536     } else {
537       if (PrintICVValues)
538         printICVs();
539       if (PrintOpenMPKernels)
540         printKernels();
541 
542       Changed |= rewriteDeviceCodeStateMachine();
543 
544       Changed |= runAttributor();
545 
546       // Recollect uses, in case Attributor deleted any.
547       OMPInfoCache.recollectUses();
548 
549       Changed |= deleteParallelRegions();
550       if (HideMemoryTransferLatency)
551         Changed |= hideMemTransfersLatency();
552       Changed |= deduplicateRuntimeCalls();
553       if (EnableParallelRegionMerging) {
554         if (mergeParallelRegions()) {
555           deduplicateRuntimeCalls();
556           Changed = true;
557         }
558       }
559     }
560 
561     return Changed;
562   }
563 
564   /// Print initial ICV values for testing.
565   /// FIXME: This should be done from the Attributor once it is added.
566   void printICVs() const {
567     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
568                                  ICV_proc_bind};
569 
570     for (Function *F : OMPInfoCache.ModuleSlice) {
571       for (auto ICV : ICVs) {
572         auto ICVInfo = OMPInfoCache.ICVs[ICV];
573         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
574           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
575                      << " Value: "
576                      << (ICVInfo.InitValue
577                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
578                              : "IMPLEMENTATION_DEFINED");
579         };
580 
581         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
582       }
583     }
584   }
585 
586   /// Print OpenMP GPU kernels for testing.
587   void printKernels() const {
588     for (Function *F : SCC) {
589       if (!OMPInfoCache.Kernels.count(F))
590         continue;
591 
592       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
593         return ORA << "OpenMP GPU kernel "
594                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
595       };
596 
597       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
598     }
599   }
600 
601   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
602   /// given it has to be the callee or a nullptr is returned.
603   static CallInst *getCallIfRegularCall(
604       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
605     CallInst *CI = dyn_cast<CallInst>(U.getUser());
606     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
607         (!RFI || CI->getCalledFunction() == RFI->Declaration))
608       return CI;
609     return nullptr;
610   }
611 
612   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
613   /// the callee or a nullptr is returned.
614   static CallInst *getCallIfRegularCall(
615       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
616     CallInst *CI = dyn_cast<CallInst>(&V);
617     if (CI && !CI->hasOperandBundles() &&
618         (!RFI || CI->getCalledFunction() == RFI->Declaration))
619       return CI;
620     return nullptr;
621   }
622 
623 private:
624   /// Merge parallel regions when it is safe.
625   bool mergeParallelRegions() {
626     const unsigned CallbackCalleeOperand = 2;
627     const unsigned CallbackFirstArgOperand = 3;
628     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
629 
630     // Check if there are any __kmpc_fork_call calls to merge.
631     OMPInformationCache::RuntimeFunctionInfo &RFI =
632         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
633 
634     if (!RFI.Declaration)
635       return false;
636 
637     // Unmergable calls that prevent merging a parallel region.
638     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
639         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
640         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
641     };
642 
643     bool Changed = false;
644     LoopInfo *LI = nullptr;
645     DominatorTree *DT = nullptr;
646 
647     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
648 
649     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
650     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
651                          BasicBlock &ContinuationIP) {
652       BasicBlock *CGStartBB = CodeGenIP.getBlock();
653       BasicBlock *CGEndBB =
654           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
655       assert(StartBB != nullptr && "StartBB should not be null");
656       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
657       assert(EndBB != nullptr && "EndBB should not be null");
658       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
659     };
660 
661     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
662                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
663       ReplacementValue = &Inner;
664       return CodeGenIP;
665     };
666 
667     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
668 
669     /// Create a sequential execution region within a merged parallel region,
670     /// encapsulated in a master construct with a barrier for synchronization.
671     auto CreateSequentialRegion = [&](Function *OuterFn,
672                                       BasicBlock *OuterPredBB,
673                                       Instruction *SeqStartI,
674                                       Instruction *SeqEndI) {
675       // Isolate the instructions of the sequential region to a separate
676       // block.
677       BasicBlock *ParentBB = SeqStartI->getParent();
678       BasicBlock *SeqEndBB =
679           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
680       BasicBlock *SeqAfterBB =
681           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
682       BasicBlock *SeqStartBB =
683           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
684 
685       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
686              "Expected a different CFG");
687       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
688       ParentBB->getTerminator()->eraseFromParent();
689 
690       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
691                            BasicBlock &ContinuationIP) {
692         BasicBlock *CGStartBB = CodeGenIP.getBlock();
693         BasicBlock *CGEndBB =
694             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
695         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
696         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
697         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
698         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
699       };
700       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
701 
702       // Find outputs from the sequential region to outside users and
703       // broadcast their values to them.
704       for (Instruction &I : *SeqStartBB) {
705         SmallPtrSet<Instruction *, 4> OutsideUsers;
706         for (User *Usr : I.users()) {
707           Instruction &UsrI = *cast<Instruction>(Usr);
708           // Ignore outputs to LT intrinsics, code extraction for the merged
709           // parallel region will fix them.
710           if (UsrI.isLifetimeStartOrEnd())
711             continue;
712 
713           if (UsrI.getParent() != SeqStartBB)
714             OutsideUsers.insert(&UsrI);
715         }
716 
717         if (OutsideUsers.empty())
718           continue;
719 
720         // Emit an alloca in the outer region to store the broadcasted
721         // value.
722         const DataLayout &DL = M.getDataLayout();
723         AllocaInst *AllocaI = new AllocaInst(
724             I.getType(), DL.getAllocaAddrSpace(), nullptr,
725             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
726 
727         // Emit a store instruction in the sequential BB to update the
728         // value.
729         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
730 
731         // Emit a load instruction and replace the use of the output value
732         // with it.
733         for (Instruction *UsrI : OutsideUsers) {
734           LoadInst *LoadI = new LoadInst(
735               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
736           UsrI->replaceUsesOfWith(&I, LoadI);
737         }
738       }
739 
740       OpenMPIRBuilder::LocationDescription Loc(
741           InsertPointTy(ParentBB, ParentBB->end()), DL);
742       InsertPointTy SeqAfterIP =
743           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
744 
745       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
746 
747       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
748 
749       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
750                         << "\n");
751     };
752 
753     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
754     // contained in BB and only separated by instructions that can be
755     // redundantly executed in parallel. The block BB is split before the first
756     // call (in MergableCIs) and after the last so the entire region we merge
757     // into a single parallel region is contained in a single basic block
758     // without any other instructions. We use the OpenMPIRBuilder to outline
759     // that block and call the resulting function via __kmpc_fork_call.
760     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
761       // TODO: Change the interface to allow single CIs expanded, e.g, to
762       // include an outer loop.
763       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
764 
765       auto Remark = [&](OptimizationRemark OR) {
766         OR << "Parallel region at "
767            << ore::NV("OpenMPParallelMergeFront",
768                       MergableCIs.front()->getDebugLoc())
769            << " merged with parallel regions at ";
770         for (auto *CI : llvm::drop_begin(MergableCIs)) {
771           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
772           if (CI != MergableCIs.back())
773             OR << ", ";
774         }
775         return OR;
776       };
777 
778       emitRemark<OptimizationRemark>(MergableCIs.front(),
779                                      "OpenMPParallelRegionMerging", Remark);
780 
781       Function *OriginalFn = BB->getParent();
782       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
783                         << " parallel regions in " << OriginalFn->getName()
784                         << "\n");
785 
786       // Isolate the calls to merge in a separate block.
787       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
788       BasicBlock *AfterBB =
789           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
790       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
791                            "omp.par.merged");
792 
793       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
794       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
795       BB->getTerminator()->eraseFromParent();
796 
797       // Create sequential regions for sequential instructions that are
798       // in-between mergable parallel regions.
799       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
800            It != End; ++It) {
801         Instruction *ForkCI = *It;
802         Instruction *NextForkCI = *(It + 1);
803 
804         // Continue if there are not in-between instructions.
805         if (ForkCI->getNextNode() == NextForkCI)
806           continue;
807 
808         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
809                                NextForkCI->getPrevNode());
810       }
811 
812       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
813                                                DL);
814       IRBuilder<>::InsertPoint AllocaIP(
815           &OriginalFn->getEntryBlock(),
816           OriginalFn->getEntryBlock().getFirstInsertionPt());
817       // Create the merged parallel region with default proc binding, to
818       // avoid overriding binding settings, and without explicit cancellation.
819       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
820           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
821           OMP_PROC_BIND_default, /* IsCancellable */ false);
822       BranchInst::Create(AfterBB, AfterIP.getBlock());
823 
824       // Perform the actual outlining.
825       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
826                                        /* AllowExtractorSinking */ true);
827 
828       Function *OutlinedFn = MergableCIs.front()->getCaller();
829 
830       // Replace the __kmpc_fork_call calls with direct calls to the outlined
831       // callbacks.
832       SmallVector<Value *, 8> Args;
833       for (auto *CI : MergableCIs) {
834         Value *Callee =
835             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
836         FunctionType *FT =
837             cast<FunctionType>(Callee->getType()->getPointerElementType());
838         Args.clear();
839         Args.push_back(OutlinedFn->getArg(0));
840         Args.push_back(OutlinedFn->getArg(1));
841         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
842              U < E; ++U)
843           Args.push_back(CI->getArgOperand(U));
844 
845         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
846         if (CI->getDebugLoc())
847           NewCI->setDebugLoc(CI->getDebugLoc());
848 
849         // Forward parameter attributes from the callback to the callee.
850         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
851              U < E; ++U)
852           for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
853             NewCI->addParamAttr(
854                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
855 
856         // Emit an explicit barrier to replace the implicit fork-join barrier.
857         if (CI != MergableCIs.back()) {
858           // TODO: Remove barrier if the merged parallel region includes the
859           // 'nowait' clause.
860           OMPInfoCache.OMPBuilder.createBarrier(
861               InsertPointTy(NewCI->getParent(),
862                             NewCI->getNextNode()->getIterator()),
863               OMPD_parallel);
864         }
865 
866         auto Remark = [&](OptimizationRemark OR) {
867           return OR << "Parallel region at "
868                     << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
869                     << " merged with "
870                     << ore::NV("OpenMPParallelMergeFront",
871                                MergableCIs.front()->getDebugLoc());
872         };
873         if (CI != MergableCIs.front())
874           emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
875                                          Remark);
876 
877         CI->eraseFromParent();
878       }
879 
880       assert(OutlinedFn != OriginalFn && "Outlining failed");
881       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
882       CGUpdater.reanalyzeFunction(*OriginalFn);
883 
884       NumOpenMPParallelRegionsMerged += MergableCIs.size();
885 
886       return true;
887     };
888 
889     // Helper function that identifes sequences of
890     // __kmpc_fork_call uses in a basic block.
891     auto DetectPRsCB = [&](Use &U, Function &F) {
892       CallInst *CI = getCallIfRegularCall(U, &RFI);
893       BB2PRMap[CI->getParent()].insert(CI);
894 
895       return false;
896     };
897 
898     BB2PRMap.clear();
899     RFI.foreachUse(SCC, DetectPRsCB);
900     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
901     // Find mergable parallel regions within a basic block that are
902     // safe to merge, that is any in-between instructions can safely
903     // execute in parallel after merging.
904     // TODO: support merging across basic-blocks.
905     for (auto &It : BB2PRMap) {
906       auto &CIs = It.getSecond();
907       if (CIs.size() < 2)
908         continue;
909 
910       BasicBlock *BB = It.getFirst();
911       SmallVector<CallInst *, 4> MergableCIs;
912 
913       /// Returns true if the instruction is mergable, false otherwise.
914       /// A terminator instruction is unmergable by definition since merging
915       /// works within a BB. Instructions before the mergable region are
916       /// mergable if they are not calls to OpenMP runtime functions that may
917       /// set different execution parameters for subsequent parallel regions.
918       /// Instructions in-between parallel regions are mergable if they are not
919       /// calls to any non-intrinsic function since that may call a non-mergable
920       /// OpenMP runtime function.
921       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
922         // We do not merge across BBs, hence return false (unmergable) if the
923         // instruction is a terminator.
924         if (I.isTerminator())
925           return false;
926 
927         if (!isa<CallInst>(&I))
928           return true;
929 
930         CallInst *CI = cast<CallInst>(&I);
931         if (IsBeforeMergableRegion) {
932           Function *CalledFunction = CI->getCalledFunction();
933           if (!CalledFunction)
934             return false;
935           // Return false (unmergable) if the call before the parallel
936           // region calls an explicit affinity (proc_bind) or number of
937           // threads (num_threads) compiler-generated function. Those settings
938           // may be incompatible with following parallel regions.
939           // TODO: ICV tracking to detect compatibility.
940           for (const auto &RFI : UnmergableCallsInfo) {
941             if (CalledFunction == RFI.Declaration)
942               return false;
943           }
944         } else {
945           // Return false (unmergable) if there is a call instruction
946           // in-between parallel regions when it is not an intrinsic. It
947           // may call an unmergable OpenMP runtime function in its callpath.
948           // TODO: Keep track of possible OpenMP calls in the callpath.
949           if (!isa<IntrinsicInst>(CI))
950             return false;
951         }
952 
953         return true;
954       };
955       // Find maximal number of parallel region CIs that are safe to merge.
956       for (auto It = BB->begin(), End = BB->end(); It != End;) {
957         Instruction &I = *It;
958         ++It;
959 
960         if (CIs.count(&I)) {
961           MergableCIs.push_back(cast<CallInst>(&I));
962           continue;
963         }
964 
965         // Continue expanding if the instruction is mergable.
966         if (IsMergable(I, MergableCIs.empty()))
967           continue;
968 
969         // Forward the instruction iterator to skip the next parallel region
970         // since there is an unmergable instruction which can affect it.
971         for (; It != End; ++It) {
972           Instruction &SkipI = *It;
973           if (CIs.count(&SkipI)) {
974             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
975                               << " due to " << I << "\n");
976             ++It;
977             break;
978           }
979         }
980 
981         // Store mergable regions found.
982         if (MergableCIs.size() > 1) {
983           MergableCIsVector.push_back(MergableCIs);
984           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
985                             << " parallel regions in block " << BB->getName()
986                             << " of function " << BB->getParent()->getName()
987                             << "\n";);
988         }
989 
990         MergableCIs.clear();
991       }
992 
993       if (!MergableCIsVector.empty()) {
994         Changed = true;
995 
996         for (auto &MergableCIs : MergableCIsVector)
997           Merge(MergableCIs, BB);
998         MergableCIsVector.clear();
999       }
1000     }
1001 
1002     if (Changed) {
1003       /// Re-collect use for fork calls, emitted barrier calls, and
1004       /// any emitted master/end_master calls.
1005       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1006       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1007       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1008       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1009     }
1010 
1011     return Changed;
1012   }
1013 
1014   /// Try to delete parallel regions if possible.
1015   bool deleteParallelRegions() {
1016     const unsigned CallbackCalleeOperand = 2;
1017 
1018     OMPInformationCache::RuntimeFunctionInfo &RFI =
1019         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1020 
1021     if (!RFI.Declaration)
1022       return false;
1023 
1024     bool Changed = false;
1025     auto DeleteCallCB = [&](Use &U, Function &) {
1026       CallInst *CI = getCallIfRegularCall(U);
1027       if (!CI)
1028         return false;
1029       auto *Fn = dyn_cast<Function>(
1030           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1031       if (!Fn)
1032         return false;
1033       if (!Fn->onlyReadsMemory())
1034         return false;
1035       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1036         return false;
1037 
1038       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1039                         << CI->getCaller()->getName() << "\n");
1040 
1041       auto Remark = [&](OptimizationRemark OR) {
1042         return OR << "Parallel region in "
1043                   << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
1044                   << " deleted";
1045       };
1046       emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
1047                                      Remark);
1048 
1049       CGUpdater.removeCallSite(*CI);
1050       CI->eraseFromParent();
1051       Changed = true;
1052       ++NumOpenMPParallelRegionsDeleted;
1053       return true;
1054     };
1055 
1056     RFI.foreachUse(SCC, DeleteCallCB);
1057 
1058     return Changed;
1059   }
1060 
1061   /// Try to eliminate runtime calls by reusing existing ones.
1062   bool deduplicateRuntimeCalls() {
1063     bool Changed = false;
1064 
1065     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1066         OMPRTL_omp_get_num_threads,
1067         OMPRTL_omp_in_parallel,
1068         OMPRTL_omp_get_cancellation,
1069         OMPRTL_omp_get_thread_limit,
1070         OMPRTL_omp_get_supported_active_levels,
1071         OMPRTL_omp_get_level,
1072         OMPRTL_omp_get_ancestor_thread_num,
1073         OMPRTL_omp_get_team_size,
1074         OMPRTL_omp_get_active_level,
1075         OMPRTL_omp_in_final,
1076         OMPRTL_omp_get_proc_bind,
1077         OMPRTL_omp_get_num_places,
1078         OMPRTL_omp_get_num_procs,
1079         OMPRTL_omp_get_place_num,
1080         OMPRTL_omp_get_partition_num_places,
1081         OMPRTL_omp_get_partition_place_nums};
1082 
1083     // Global-tid is handled separately.
1084     SmallSetVector<Value *, 16> GTIdArgs;
1085     collectGlobalThreadIdArguments(GTIdArgs);
1086     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1087                       << " global thread ID arguments\n");
1088 
1089     for (Function *F : SCC) {
1090       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1091         Changed |= deduplicateRuntimeCalls(
1092             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1093 
1094       // __kmpc_global_thread_num is special as we can replace it with an
1095       // argument in enough cases to make it worth trying.
1096       Value *GTIdArg = nullptr;
1097       for (Argument &Arg : F->args())
1098         if (GTIdArgs.count(&Arg)) {
1099           GTIdArg = &Arg;
1100           break;
1101         }
1102       Changed |= deduplicateRuntimeCalls(
1103           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1104     }
1105 
1106     return Changed;
1107   }
1108 
1109   /// Tries to hide the latency of runtime calls that involve host to
1110   /// device memory transfers by splitting them into their "issue" and "wait"
1111   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1112   /// moved downards as much as possible. The "issue" issues the memory transfer
1113   /// asynchronously, returning a handle. The "wait" waits in the returned
1114   /// handle for the memory transfer to finish.
1115   bool hideMemTransfersLatency() {
1116     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1117     bool Changed = false;
1118     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1119       auto *RTCall = getCallIfRegularCall(U, &RFI);
1120       if (!RTCall)
1121         return false;
1122 
1123       OffloadArray OffloadArrays[3];
1124       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1125         return false;
1126 
1127       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1128 
1129       // TODO: Check if can be moved upwards.
1130       bool WasSplit = false;
1131       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1132       if (WaitMovementPoint)
1133         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1134 
1135       Changed |= WasSplit;
1136       return WasSplit;
1137     };
1138     RFI.foreachUse(SCC, SplitMemTransfers);
1139 
1140     return Changed;
1141   }
1142 
1143   void analysisGlobalization() {
1144     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1145 
1146     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1147       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1148         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1149           return ORA
1150                  << "Found thread data sharing on the GPU. "
1151                  << "Expect degraded performance due to data globalization.";
1152         };
1153         emitRemark<OptimizationRemarkAnalysis>(CI, "OpenMPGlobalization",
1154                                                Remark);
1155       }
1156 
1157       return false;
1158     };
1159 
1160     RFI.foreachUse(SCC, CheckGlobalization);
1161   }
1162 
1163   /// Maps the values stored in the offload arrays passed as arguments to
1164   /// \p RuntimeCall into the offload arrays in \p OAs.
1165   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1166                                 MutableArrayRef<OffloadArray> OAs) {
1167     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1168 
1169     // A runtime call that involves memory offloading looks something like:
1170     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1171     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1172     // ...)
1173     // So, the idea is to access the allocas that allocate space for these
1174     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1175     // Therefore:
1176     // i8** %offload_baseptrs.
1177     Value *BasePtrsArg =
1178         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1179     // i8** %offload_ptrs.
1180     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1181     // i8** %offload_sizes.
1182     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1183 
1184     // Get values stored in **offload_baseptrs.
1185     auto *V = getUnderlyingObject(BasePtrsArg);
1186     if (!isa<AllocaInst>(V))
1187       return false;
1188     auto *BasePtrsArray = cast<AllocaInst>(V);
1189     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1190       return false;
1191 
1192     // Get values stored in **offload_baseptrs.
1193     V = getUnderlyingObject(PtrsArg);
1194     if (!isa<AllocaInst>(V))
1195       return false;
1196     auto *PtrsArray = cast<AllocaInst>(V);
1197     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1198       return false;
1199 
1200     // Get values stored in **offload_sizes.
1201     V = getUnderlyingObject(SizesArg);
1202     // If it's a [constant] global array don't analyze it.
1203     if (isa<GlobalValue>(V))
1204       return isa<Constant>(V);
1205     if (!isa<AllocaInst>(V))
1206       return false;
1207 
1208     auto *SizesArray = cast<AllocaInst>(V);
1209     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1210       return false;
1211 
1212     return true;
1213   }
1214 
1215   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1216   /// For now this is a way to test that the function getValuesInOffloadArrays
1217   /// is working properly.
1218   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1219   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1220     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1221 
1222     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1223     std::string ValuesStr;
1224     raw_string_ostream Printer(ValuesStr);
1225     std::string Separator = " --- ";
1226 
1227     for (auto *BP : OAs[0].StoredValues) {
1228       BP->print(Printer);
1229       Printer << Separator;
1230     }
1231     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1232     ValuesStr.clear();
1233 
1234     for (auto *P : OAs[1].StoredValues) {
1235       P->print(Printer);
1236       Printer << Separator;
1237     }
1238     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1239     ValuesStr.clear();
1240 
1241     for (auto *S : OAs[2].StoredValues) {
1242       S->print(Printer);
1243       Printer << Separator;
1244     }
1245     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1246   }
1247 
1248   /// Returns the instruction where the "wait" counterpart \p RuntimeCall can be
1249   /// moved. Returns nullptr if the movement is not possible, or not worth it.
1250   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1251     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1252     //  Make it traverse the CFG.
1253 
1254     Instruction *CurrentI = &RuntimeCall;
1255     bool IsWorthIt = false;
1256     while ((CurrentI = CurrentI->getNextNode())) {
1257 
1258       // TODO: Once we detect the regions to be offloaded we should use the
1259       //  alias analysis manager to check if CurrentI may modify one of
1260       //  the offloaded regions.
1261       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1262         if (IsWorthIt)
1263           return CurrentI;
1264 
1265         return nullptr;
1266       }
1267 
1268       // FIXME: For now if we move it over anything without side effect
1269       //  is worth it.
1270       IsWorthIt = true;
1271     }
1272 
1273     // Return end of BasicBlock.
1274     return RuntimeCall.getParent()->getTerminator();
1275   }
1276 
1277   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1278   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1279                                Instruction &WaitMovementPoint) {
1280     // Create stack allocated handle (__tgt_async_info) at the beginning of the
1281     // function. Used for storing information of the async transfer, allowing to
1282     // wait on it later.
1283     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1284     auto *F = RuntimeCall.getCaller();
1285     Instruction *FirstInst = &(F->getEntryBlock().front());
1286     AllocaInst *Handle = new AllocaInst(
1287         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1288 
1289     // Add "issue" runtime call declaration:
1290     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1291     //   i8**, i8**, i64*, i64*)
1292     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1293         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1294 
1295     // Change RuntimeCall call site for its asynchronous version.
1296     SmallVector<Value *, 16> Args;
1297     for (auto &Arg : RuntimeCall.args())
1298       Args.push_back(Arg.get());
1299     Args.push_back(Handle);
1300 
1301     CallInst *IssueCallsite =
1302         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1303     RuntimeCall.eraseFromParent();
1304 
1305     // Add "wait" runtime call declaration:
1306     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1307     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1308         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1309 
1310     Value *WaitParams[2] = {
1311         IssueCallsite->getArgOperand(
1312             OffloadArray::DeviceIDArgNum), // device_id.
1313         Handle                             // handle to wait on.
1314     };
1315     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1316 
1317     return true;
1318   }
1319 
1320   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1321                                     bool GlobalOnly, bool &SingleChoice) {
1322     if (CurrentIdent == NextIdent)
1323       return CurrentIdent;
1324 
1325     // TODO: Figure out how to actually combine multiple debug locations. For
1326     //       now we just keep an existing one if there is a single choice.
1327     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1328       SingleChoice = !CurrentIdent;
1329       return NextIdent;
1330     }
1331     return nullptr;
1332   }
1333 
1334   /// Return an `struct ident_t*` value that represents the ones used in the
1335   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1336   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1337   /// return value we create one from scratch. We also do not yet combine
1338   /// information, e.g., the source locations, see combinedIdentStruct.
1339   Value *
1340   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1341                                  Function &F, bool GlobalOnly) {
1342     bool SingleChoice = true;
1343     Value *Ident = nullptr;
1344     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1345       CallInst *CI = getCallIfRegularCall(U, &RFI);
1346       if (!CI || &F != &Caller)
1347         return false;
1348       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1349                                   /* GlobalOnly */ true, SingleChoice);
1350       return false;
1351     };
1352     RFI.foreachUse(SCC, CombineIdentStruct);
1353 
1354     if (!Ident || !SingleChoice) {
1355       // The IRBuilder uses the insertion block to get to the module, this is
1356       // unfortunate but we work around it for now.
1357       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1358         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1359             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1361       // TODO: Use the debug locations of the calls instead.
1362       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1363       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1364     }
1365     return Ident;
1366   }
1367 
1368   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1369   /// \p ReplVal if given.
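  ///
  /// A minimal sketch of the effect, assuming \p RFI is
  /// __kmpc_global_thread_num (illustrative IR):
  ///   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   ...
  ///   %tid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  /// is rewritten so that a single call, hoisted into the entry block,
  /// replaces both %tid0 and %tid1.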
1370   bool deduplicateRuntimeCalls(Function &F,
1371                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1372                                Value *ReplVal = nullptr) {
1373     auto *UV = RFI.getUseVector(F);
1374     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1375       return false;
1376 
    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value" : "") << "\n");
1380 
1381     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1382                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1383            "Unexpected replacement value!");
1384 
1385     // TODO: Use dominance to find a good position instead.
1386     auto CanBeMoved = [this](CallBase &CB) {
1387       unsigned NumArgs = CB.getNumArgOperands();
1388       if (NumArgs == 0)
1389         return true;
1390       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1391         return false;
1392       for (unsigned u = 1; u < NumArgs; ++u)
1393         if (isa<Instruction>(CB.getArgOperand(u)))
1394           return false;
1395       return true;
1396     };
1397 
1398     if (!ReplVal) {
1399       for (Use *U : *UV)
1400         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1401           if (!CanBeMoved(*CI))
1402             continue;
1403 
1404           auto Remark = [&](OptimizationRemark OR) {
1405             return OR << "OpenMP runtime call "
1406                       << ore::NV("OpenMPOptRuntime", RFI.Name)
1407                       << " moved to beginning of OpenMP region";
1408           };
1409           emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);
1410 
1411           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1412           ReplVal = CI;
1413           break;
1414         }
1415       if (!ReplVal)
1416         return false;
1417     }
1418 
1419     // If we use a call as a replacement value we need to make sure the ident is
1420     // valid at the new location. For now we just pick a global one, either
1421     // existing and used by one of the calls, or created from scratch.
1422     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1423       if (CI->getNumArgOperands() > 0 &&
1424           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1425         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1426                                                       /* GlobalOnly */ true);
1427         CI->setArgOperand(0, Ident);
1428       }
1429     }
1430 
1431     bool Changed = false;
1432     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1433       CallInst *CI = getCallIfRegularCall(U, &RFI);
1434       if (!CI || CI == ReplVal || &F != &Caller)
1435         return false;
1436       assert(CI->getCaller() == &F && "Unexpected call!");
1437 
1438       auto Remark = [&](OptimizationRemark OR) {
1439         return OR << "OpenMP runtime call "
1440                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
1441       };
1442       emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeDeduplicated", Remark);
1443 
1444       CGUpdater.removeCallSite(*CI);
1445       CI->replaceAllUsesWith(ReplVal);
1446       CI->eraseFromParent();
1447       ++NumOpenMPRuntimeCallsDeduplicated;
1448       Changed = true;
1449       return true;
1450     };
1451     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1452 
1453     return Changed;
1454   }
1455 
1456   /// Collect arguments that represent the global thread id in \p GTIdArgs.
1457   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1458     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1459     //       initialization. We could define an AbstractAttribute instead and
1460     //       run the Attributor here once it can be run as an SCC pass.
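
    // A sketch of the propagation performed below (illustrative IR):
    //   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
    //   call void @outlined(i32 %gtid, ...)
    // If every call site of @outlined passes a known GTId (or a direct
    // __kmpc_global_thread_num result) in that position, the corresponding
    // argument of @outlined is recorded as a GTId as well.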
1461 
1462     // Helper to check the argument \p ArgNo at all call sites of \p F for
1463     // a GTId.
1464     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1465       if (!F.hasLocalLinkage())
1466         return false;
1467       for (Use &U : F.uses()) {
1468         if (CallInst *CI = getCallIfRegularCall(U)) {
1469           Value *ArgOp = CI->getArgOperand(ArgNo);
1470           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1471               getCallIfRegularCall(
1472                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1473             continue;
1474         }
1475         return false;
1476       }
1477       return true;
1478     };
1479 
1480     // Helper to identify uses of a GTId as GTId arguments.
1481     auto AddUserArgs = [&](Value &GTId) {
1482       for (Use &U : GTId.uses())
1483         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1484           if (CI->isArgOperand(&U))
1485             if (Function *Callee = CI->getCalledFunction())
1486               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1487                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1488     };
1489 
1490     // The argument users of __kmpc_global_thread_num calls are GTIds.
1491     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1492         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1493 
1494     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1495       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1496         AddUserArgs(*CI);
1497       return false;
1498     });
1499 
1500     // Transitively search for more arguments by looking at the users of the
1501     // ones we know already. During the search the GTIdArgs vector is extended
1502     // so we cannot cache the size nor can we use a range based for.
1503     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1504       AddUserArgs(*GTIdArgs[u]);
1505   }
1506 
1507   /// Kernel (=GPU) optimizations and utility functions
1508   ///
1509   ///{{
1510 
1511   /// Check if \p F is a kernel, hence entry point for target offloading.
1512   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1513 
1514   /// Cache to remember the unique kernel for a function.
1515   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1516 
1517   /// Find the unique kernel that will execute \p F, if any.
1518   Kernel getUniqueKernelFor(Function &F);
1519 
1520   /// Find the unique kernel that will execute \p I, if any.
1521   Kernel getUniqueKernelFor(Instruction &I) {
1522     return getUniqueKernelFor(*I.getFunction());
1523   }
1524 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
1527   bool rewriteDeviceCodeStateMachine();
1528 
1529   ///
1530   ///}}
1531 
1532   /// Emit a remark generically
1533   ///
1534   /// This template function can be used to generically emit a remark. The
1535   /// RemarkKind should be one of the following:
1536   ///   - OptimizationRemark to indicate a successful optimization attempt
1537   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1538   ///   - OptimizationRemarkAnalysis to provide additional information about an
1539   ///     optimization attempt
1540   ///
1541   /// The remark is built using a callback function provided by the caller that
1542   /// takes a RemarkKind as input and returns a RemarkKind.
1543   template <typename RemarkKind, typename RemarkCallBack>
1544   void emitRemark(Instruction *I, StringRef RemarkName,
1545                   RemarkCallBack &&RemarkCB) const {
1546     Function *F = I->getParent()->getParent();
1547     auto &ORE = OREGetter(F);
1548 
1549     ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1550   }
1551 
1552   /// Emit a remark on a function.
1553   template <typename RemarkKind, typename RemarkCallBack>
1554   void emitRemark(Function *F, StringRef RemarkName,
1555                   RemarkCallBack &&RemarkCB) const {
1556     auto &ORE = OREGetter(F);
1557 
1558     ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1559   }
1560 
1561   /// The underlying module.
1562   Module &M;
1563 
1564   /// The SCC we are operating on.
1565   SmallVectorImpl<Function *> &SCC;
1566 
1567   /// Callback to update the call graph, the first argument is a removed call,
1568   /// the second an optional replacement call.
1569   CallGraphUpdater &CGUpdater;
1570 
1571   /// Callback to get an OptimizationRemarkEmitter from a Function *
1572   OptimizationRemarkGetter OREGetter;
1573 
1574   /// OpenMP-specific information cache. Also Used for Attributor runs.
1575   OMPInformationCache &OMPInfoCache;
1576 
1577   /// Attributor instance.
1578   Attributor &A;
1579 
1580   /// Helper function to run Attributor on SCC.
1581   bool runAttributor() {
1582     if (SCC.empty())
1583       return false;
1584 
1585     registerAAs();
1586 
1587     ChangeStatus Changed = A.run();
1588 
1589     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1590                       << " functions, result: " << Changed << ".\n");
1591 
1592     return Changed == ChangeStatus::CHANGED;
1593   }
1594 
1595   /// Populate the Attributor with abstract attribute opportunities in the
1596   /// function.
1597   void registerAAs() {
1598     if (SCC.empty())
1599       return;
1600 
1601     // Create CallSite AA for all Getters.
1602     for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
1603       auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
1604 
1605       auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
1606 
1607       auto CreateAA = [&](Use &U, Function &Caller) {
1608         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
1609         if (!CI)
1610           return false;
1611 
1612         auto &CB = cast<CallBase>(*CI);
1613 
1614         IRPosition CBPos = IRPosition::callsite_function(CB);
1615         A.getOrCreateAAFor<AAICVTracker>(CBPos);
1616         return false;
1617       };
1618 
1619       GetterRFI.foreachUse(SCC, CreateAA);
1620     }
1621     auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1622     auto CreateAA = [&](Use &U, Function &F) {
1623       A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
1624       return false;
1625     };
1626     GlobalizationRFI.foreachUse(SCC, CreateAA);
1627 
1628     // Create an ExecutionDomain AA for every function and a HeapToStack AA for
1629     // every function if there is a device kernel.
1630     for (auto *F : SCC) {
1631       if (!F->isDeclaration())
1632         A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
1633       if (!OMPInfoCache.Kernels.empty())
1634         A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
1635     }
1636   }
1637 };
1638 
1639 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1640   if (!OMPInfoCache.ModuleSlice.count(&F))
1641     return nullptr;
1642 
1643   // Use a scope to keep the lifetime of the CachedKernel short.
1644   {
1645     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1646     if (CachedKernel)
1647       return *CachedKernel;
1648 
1649     // TODO: We should use an AA to create an (optimistic and callback
1650     //       call-aware) call graph. For now we stick to simple patterns that
1651     //       are less powerful, basically the worst fixpoint.
1652     if (isKernel(F)) {
1653       CachedKernel = Kernel(&F);
1654       return *CachedKernel;
1655     }
1656 
1657     CachedKernel = nullptr;
1658     if (!F.hasLocalLinkage()) {
1659 
1660       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1661       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1662         return ORA
1663                << "[OMP100] Potentially unknown OpenMP target region caller";
1664       };
1665       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1666 
1667       return nullptr;
1668     }
1669   }
1670 
1671   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1672     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1673       // Allow use in equality comparisons.
1674       if (Cmp->isEquality())
1675         return getUniqueKernelFor(*Cmp);
1676       return nullptr;
1677     }
1678     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1679       // Allow direct calls.
1680       if (CB->isCallee(&U))
1681         return getUniqueKernelFor(*CB);
1682 
1683       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1684           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1685       // Allow the use in __kmpc_parallel_51 calls.
1686       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1687         return getUniqueKernelFor(*CB);
1688       return nullptr;
1689     }
1690     // Disallow every other use.
1691     return nullptr;
1692   };
1693 
1694   // TODO: In the future we want to track more than just a unique kernel.
1695   SmallPtrSet<Kernel, 2> PotentialKernels;
1696   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1697     PotentialKernels.insert(GetUniqueKernelForUse(U));
1698   });
1699 
1700   Kernel K = nullptr;
1701   if (PotentialKernels.size() == 1)
1702     K = *PotentialKernels.begin();
1703 
1704   // Cache the result.
1705   UniqueKernelMap[&F] = K;
1706 
1707   return K;
1708 }
1709 
1710 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1711   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1712       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1713 
1714   bool Changed = false;
1715   if (!KernelParallelRFI)
1716     return Changed;
1717 
1718   for (Function *F : SCC) {
1719 
    // Check if the function is a use in a __kmpc_parallel_51 call at all.
1722     bool UnknownUse = false;
1723     bool KernelParallelUse = false;
1724     unsigned NumDirectCalls = 0;
1725 
1726     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1727     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1728       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1729         if (CB->isCallee(&U)) {
1730           ++NumDirectCalls;
1731           return;
1732         }
1733 
1734       if (isa<ICmpInst>(U.getUser())) {
1735         ToBeReplacedStateMachineUses.push_back(&U);
1736         return;
1737       }
1738 
1739       // Find wrapper functions that represent parallel kernels.
1740       CallInst *CI =
1741           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1742       const unsigned int WrapperFunctionArgNo = 6;
1743       if (!KernelParallelUse && CI &&
1744           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1745         KernelParallelUse = true;
1746         ToBeReplacedStateMachineUses.push_back(&U);
1747         return;
1748       }
1749       UnknownUse = true;
1750     });
1751 
    // Do not emit a remark if we haven't seen a __kmpc_parallel_51 use.
1754     if (!KernelParallelUse)
1755       continue;
1756 
1757     {
1758       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1759         return ORA << "Found a parallel region that is called in a target "
1760                       "region but not part of a combined target construct nor "
1761                       "nested inside a target construct without intermediate "
1762                       "code. This can lead to excessive register usage for "
1763                       "unrelated target regions in the same translation unit "
1764                       "due to spurious call edges assumed by ptxas.";
1765       };
1766       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD",
1767                                              Remark);
1768     }
1769 
1770     // If this ever hits, we should investigate.
1771     // TODO: Checking the number of uses is not a necessary restriction and
1772     // should be lifted.
1773     if (UnknownUse || NumDirectCalls != 1 ||
1774         ToBeReplacedStateMachineUses.size() != 2) {
1775       {
1776         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1777           return ORA << "Parallel region is used in "
1778                      << (UnknownUse ? "unknown" : "unexpected")
1779                      << " ways; will not attempt to rewrite the state machine.";
1780         };
1781         emitRemark<OptimizationRemarkAnalysis>(
1782             F, "OpenMPParallelRegionInNonSPMD", Remark);
1783       }
1784       continue;
1785     }
1786 
1787     // Even if we have __kmpc_parallel_51 calls, we (for now) give
1788     // up if the function is not called from a unique kernel.
1789     Kernel K = getUniqueKernelFor(*F);
1790     if (!K) {
1791       {
1792         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1793           return ORA << "Parallel region is not known to be called from a "
1794                         "unique single target region, maybe the surrounding "
1795                         "function has external linkage?; will not attempt to "
1796                         "rewrite the state machine use.";
1797         };
1798         emitRemark<OptimizationRemarkAnalysis>(
1799             F, "OpenMPParallelRegionInMultipleKernesl", Remark);
1800       }
1801       continue;
1802     }
1803 
1804     // We now know F is a parallel body function called only from the kernel K.
1805     // We also identified the state machine uses in which we replace the
1806     // function pointer by a new global symbol for identification purposes. This
1807     // ensures only direct calls to the function are left.
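    //
    // A rough sketch in the device IR (illustrative):
    //   %cmp = icmp eq i8* %work_fn, bitcast (void ()* @F to i8*)  ; before
    //   %cmp = icmp eq i8* %work_fn, @F.ID                         ; after
    // The wrapper argument of the __kmpc_parallel_51 call is rewritten from
    // @F to @F.ID in the same way, so @F is no longer address-taken outside
    // of direct calls.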
1808 
1809     {
      auto RemarkParallelRegion = [&](OptimizationRemarkAnalysis ORA) {
1811         return ORA << "Specialize parallel region that is only reached from a "
1812                       "single target region to avoid spurious call edges and "
1813                       "excessive register usage in other target regions. "
1814                       "(parallel region ID: "
1815                    << ore::NV("OpenMPParallelRegion", F->getName())
1816                    << ", kernel ID: "
1817                    << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1818       };
      emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD",
                                             RemarkParallelRegion);
1821       auto RemarkKernel = [&](OptimizationRemarkAnalysis ORA) {
1822         return ORA << "Target region containing the parallel region that is "
1823                       "specialized. (parallel region ID: "
1824                    << ore::NV("OpenMPParallelRegion", F->getName())
1825                    << ", kernel ID: "
1826                    << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1827       };
1828       emitRemark<OptimizationRemarkAnalysis>(K, "OpenMPParallelRegionInNonSPMD",
1829                                              RemarkKernel);
1830     }
1831 
1832     Module &M = *F->getParent();
1833     Type *Int8Ty = Type::getInt8Ty(M.getContext());
1834 
1835     auto *ID = new GlobalVariable(
1836         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
1837         UndefValue::get(Int8Ty), F->getName() + ".ID");
1838 
1839     for (Use *U : ToBeReplacedStateMachineUses)
1840       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
1841 
1842     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
1843 
1844     Changed = true;
1845   }
1846 
1847   return Changed;
1848 }
1849 
1850 /// Abstract Attribute for tracking ICV values.
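///
/// For example (illustrative; only the nthreads ICV is tracked right now):
///   call void @omp_set_num_threads(i32 4)
///   %n = call i32 @omp_get_max_threads()
/// may allow the getter call to be replaced by the constant 4.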
1851 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1852   using Base = StateWrapper<BooleanState, AbstractAttribute>;
1853   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1854 
1855   void initialize(Attributor &A) override {
1856     Function *F = getAnchorScope();
1857     if (!F || !A.isFunctionIPOAmendable(*F))
1858       indicatePessimisticFixpoint();
1859   }
1860 
1861   /// Returns true if value is assumed to be tracked.
1862   bool isAssumedTracked() const { return getAssumed(); }
1863 
1864   /// Returns true if value is known to be tracked.
1865   bool isKnownTracked() const { return getAssumed(); }
1866 
  /// Create an abstract attribute view for the position \p IRP.
1868   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1869 
1870   /// Return the value with which \p I can be replaced for specific \p ICV.
1871   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1872                                                 const Instruction *I,
1873                                                 Attributor &A) const {
1874     return None;
1875   }
1876 
1877   /// Return an assumed unique ICV value if a single candidate is found. If
1878   /// there cannot be one, return a nullptr. If it is not clear yet, return the
1879   /// Optional::NoneType.
1880   virtual Optional<Value *>
1881   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
1882 
  // Currently only nthreads is being tracked.
  // This array will only grow with time.
1885   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1886 
1887   /// See AbstractAttribute::getName()
1888   const std::string getName() const override { return "AAICVTracker"; }
1889 
1890   /// See AbstractAttribute::getIdAddr()
1891   const char *getIdAddr() const override { return &ID; }
1892 
1893   /// This function should return true if the type of the \p AA is AAICVTracker
1894   static bool classof(const AbstractAttribute *AA) {
1895     return (AA->getIdAddr() == &ID);
1896   }
1897 
1898   static const char ID;
1899 };
1900 
1901 struct AAICVTrackerFunction : public AAICVTracker {
1902   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1903       : AAICVTracker(IRP, A) {}
1904 
  // FIXME: Come up with a better string.
1906   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
1907 
1908   // FIXME: come up with some stats.
1909   void trackStatistics() const override {}
1910 
1911   /// We don't manifest anything for this AA.
1912   ChangeStatus manifest(Attributor &A) override {
1913     return ChangeStatus::UNCHANGED;
1914   }
1915 
  // Map of ICVs to their values at a specific program point.
1917   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
1918                   InternalControlVar::ICV___last>
1919       ICVReplacementValuesMap;
1920 
1921   ChangeStatus updateImpl(Attributor &A) override {
1922     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
1923 
1924     Function *F = getAnchorScope();
1925 
1926     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1927 
1928     for (InternalControlVar ICV : TrackableICVs) {
1929       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1930 
1931       auto &ValuesMap = ICVReplacementValuesMap[ICV];
1932       auto TrackValues = [&](Use &U, Function &) {
1933         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
1934         if (!CI)
1935           return false;
1936 
        // FIXME: Handle setters with more than one argument.
        // Track the new value.
1939         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
1940           HasChanged = ChangeStatus::CHANGED;
1941 
1942         return false;
1943       };
1944 
1945       auto CallCheck = [&](Instruction &I) {
1946         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
1947         if (ReplVal.hasValue() &&
1948             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
1949           HasChanged = ChangeStatus::CHANGED;
1950 
1951         return true;
1952       };
1953 
1954       // Track all changes of an ICV.
1955       SetterRFI.foreachUse(TrackValues, F);
1956 
1957       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
1958                                 /* CheckBBLivenessOnly */ true);
1959 
      // TODO: Figure out a way to avoid adding an entry in
      //       ICVReplacementValuesMap.
1962       Instruction *Entry = &F->getEntryBlock().front();
1963       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
1964         ValuesMap.insert(std::make_pair(Entry, nullptr));
1965     }
1966 
1967     return HasChanged;
1968   }
1969 
  /// Helper to check if \p I is a call and get the value for it if it is
1971   /// unique.
1972   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
1973                                     InternalControlVar &ICV) const {
1974 
1975     const auto *CB = dyn_cast<CallBase>(I);
1976     if (!CB || CB->hasFnAttr("no_openmp") ||
1977         CB->hasFnAttr("no_openmp_routines"))
1978       return None;
1979 
1980     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
1981     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
1982     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
1983     Function *CalledFunction = CB->getCalledFunction();
1984 
1985     // Indirect call, assume ICV changes.
1986     if (CalledFunction == nullptr)
1987       return nullptr;
1988     if (CalledFunction == GetterRFI.Declaration)
1989       return None;
1990     if (CalledFunction == SetterRFI.Declaration) {
1991       if (ICVReplacementValuesMap[ICV].count(I))
1992         return ICVReplacementValuesMap[ICV].lookup(I);
1993 
1994       return nullptr;
1995     }
1996 
1997     // Since we don't know, assume it changes the ICV.
1998     if (CalledFunction->isDeclaration())
1999       return nullptr;
2000 
2001     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2002         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2003 
2004     if (ICVTrackingAA.isAssumedTracked())
2005       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2006 
2007     // If we don't know, assume it changes.
2008     return nullptr;
2009   }
2010 
  // We don't compute a unique value for a function, so return None.
2012   Optional<Value *>
2013   getUniqueReplacementValue(InternalControlVar ICV) const override {
2014     return None;
2015   }
2016 
2017   /// Return the value with which \p I can be replaced for specific \p ICV.
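  ///
  /// The lookup walks backwards from \p I through the instructions of its
  /// block and then through predecessor terminators, stopping on each path at
  /// the first known setter or at a call that might change the ICV. Differing
  /// values on different paths result in nullptr (unknown).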
2018   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2019                                         const Instruction *I,
2020                                         Attributor &A) const override {
2021     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2022     if (ValuesMap.count(I))
2023       return ValuesMap.lookup(I);
2024 
2025     SmallVector<const Instruction *, 16> Worklist;
2026     SmallPtrSet<const Instruction *, 16> Visited;
2027     Worklist.push_back(I);
2028 
2029     Optional<Value *> ReplVal;
2030 
2031     while (!Worklist.empty()) {
2032       const Instruction *CurrInst = Worklist.pop_back_val();
2033       if (!Visited.insert(CurrInst).second)
2034         continue;
2035 
2036       const BasicBlock *CurrBB = CurrInst->getParent();
2037 
2038       // Go up and look for all potential setters/calls that might change the
2039       // ICV.
2040       while ((CurrInst = CurrInst->getPrevNode())) {
2041         if (ValuesMap.count(CurrInst)) {
2042           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2043           // Unknown value, track new.
2044           if (!ReplVal.hasValue()) {
2045             ReplVal = NewReplVal;
2046             break;
2047           }
2048 
          // If we found a new value, we can't know the ICV value anymore.
2050           if (NewReplVal.hasValue())
2051             if (ReplVal != NewReplVal)
2052               return nullptr;
2053 
2054           break;
2055         }
2056 
2057         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2058         if (!NewReplVal.hasValue())
2059           continue;
2060 
2061         // Unknown value, track new.
2062         if (!ReplVal.hasValue()) {
2063           ReplVal = NewReplVal;
2064           break;
2065         }
2066 
        // We found a new value, so we can't know the ICV value anymore.
2069         if (ReplVal != NewReplVal)
2070           return nullptr;
2071       }
2072 
2073       // If we are in the same BB and we have a value, we are done.
2074       if (CurrBB == I->getParent() && ReplVal.hasValue())
2075         return ReplVal;
2076 
2077       // Go through all predecessors and add terminators for analysis.
2078       for (const BasicBlock *Pred : predecessors(CurrBB))
2079         if (const Instruction *Terminator = Pred->getTerminator())
2080           Worklist.push_back(Terminator);
2081     }
2082 
2083     return ReplVal;
2084   }
2085 };
2086 
2087 struct AAICVTrackerFunctionReturned : AAICVTracker {
2088   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2089       : AAICVTracker(IRP, A) {}
2090 
  // FIXME: Come up with a better string.
2092   const std::string getAsStr() const override {
2093     return "ICVTrackerFunctionReturned";
2094   }
2095 
2096   // FIXME: come up with some stats.
2097   void trackStatistics() const override {}
2098 
2099   /// We don't manifest anything for this AA.
2100   ChangeStatus manifest(Attributor &A) override {
2101     return ChangeStatus::UNCHANGED;
2102   }
2103 
  // Map of ICVs to their values at a specific program point.
2105   EnumeratedArray<Optional<Value *>, InternalControlVar,
2106                   InternalControlVar::ICV___last>
2107       ICVReplacementValuesMap;
2108 
2109   /// Return the value with which \p I can be replaced for specific \p ICV.
2110   Optional<Value *>
2111   getUniqueReplacementValue(InternalControlVar ICV) const override {
2112     return ICVReplacementValuesMap[ICV];
2113   }
2114 
2115   ChangeStatus updateImpl(Attributor &A) override {
2116     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2117     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2118         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2119 
2120     if (!ICVTrackingAA.isAssumedTracked())
2121       return indicatePessimisticFixpoint();
2122 
2123     for (InternalControlVar ICV : TrackableICVs) {
2124       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2125       Optional<Value *> UniqueICVValue;
2126 
2127       auto CheckReturnInst = [&](Instruction &I) {
2128         Optional<Value *> NewReplVal =
2129             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2130 
2131         // If we found a second ICV value there is no unique returned value.
2132         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2133           return false;
2134 
2135         UniqueICVValue = NewReplVal;
2136 
2137         return true;
2138       };
2139 
2140       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2141                                      /* CheckBBLivenessOnly */ true))
2142         UniqueICVValue = nullptr;
2143 
2144       if (UniqueICVValue == ReplVal)
2145         continue;
2146 
2147       ReplVal = UniqueICVValue;
2148       Changed = ChangeStatus::CHANGED;
2149     }
2150 
2151     return Changed;
2152   }
2153 };
2154 
2155 struct AAICVTrackerCallSite : AAICVTracker {
2156   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2157       : AAICVTracker(IRP, A) {}
2158 
2159   void initialize(Attributor &A) override {
2160     Function *F = getAnchorScope();
2161     if (!F || !A.isFunctionIPOAmendable(*F))
2162       indicatePessimisticFixpoint();
2163 
2164     // We only initialize this AA for getters, so we need to know which ICV it
2165     // gets.
2166     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2167     for (InternalControlVar ICV : TrackableICVs) {
2168       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2169       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2170       if (Getter.Declaration == getAssociatedFunction()) {
2171         AssociatedICV = ICVInfo.Kind;
2172         return;
2173       }
2174     }
2175 
2176     /// Unknown ICV.
2177     indicatePessimisticFixpoint();
2178   }
2179 
2180   ChangeStatus manifest(Attributor &A) override {
2181     if (!ReplVal.hasValue() || !ReplVal.getValue())
2182       return ChangeStatus::UNCHANGED;
2183 
2184     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2185     A.deleteAfterManifest(*getCtxI());
2186 
2187     return ChangeStatus::CHANGED;
2188   }
2189 
  // FIXME: Come up with a better string.
2191   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2192 
2193   // FIXME: come up with some stats.
2194   void trackStatistics() const override {}
2195 
2196   InternalControlVar AssociatedICV;
2197   Optional<Value *> ReplVal;
2198 
2199   ChangeStatus updateImpl(Attributor &A) override {
2200     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2201         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2202 
2203     // We don't have any information, so we assume it changes the ICV.
2204     if (!ICVTrackingAA.isAssumedTracked())
2205       return indicatePessimisticFixpoint();
2206 
2207     Optional<Value *> NewReplVal =
2208         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2209 
2210     if (ReplVal == NewReplVal)
2211       return ChangeStatus::UNCHANGED;
2212 
2213     ReplVal = NewReplVal;
2214     return ChangeStatus::CHANGED;
2215   }
2216 
  // Return the value with which the associated value can be replaced for the
  // specific \p ICV.
2219   Optional<Value *>
2220   getUniqueReplacementValue(InternalControlVar ICV) const override {
2221     return ReplVal;
2222   }
2223 };
2224 
2225 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2226   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2227       : AAICVTracker(IRP, A) {}
2228 
  // FIXME: Come up with a better string.
2230   const std::string getAsStr() const override {
2231     return "ICVTrackerCallSiteReturned";
2232   }
2233 
2234   // FIXME: come up with some stats.
2235   void trackStatistics() const override {}
2236 
2237   /// We don't manifest anything for this AA.
2238   ChangeStatus manifest(Attributor &A) override {
2239     return ChangeStatus::UNCHANGED;
2240   }
2241 
  // Map of ICVs to their values at a specific program point.
2243   EnumeratedArray<Optional<Value *>, InternalControlVar,
2244                   InternalControlVar::ICV___last>
2245       ICVReplacementValuesMap;
2246 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2249   Optional<Value *>
2250   getUniqueReplacementValue(InternalControlVar ICV) const override {
2251     return ICVReplacementValuesMap[ICV];
2252   }
2253 
2254   ChangeStatus updateImpl(Attributor &A) override {
2255     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2256     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2257         *this, IRPosition::returned(*getAssociatedFunction()),
2258         DepClassTy::REQUIRED);
2259 
2260     // We don't have any information, so we assume it changes the ICV.
2261     if (!ICVTrackingAA.isAssumedTracked())
2262       return indicatePessimisticFixpoint();
2263 
2264     for (InternalControlVar ICV : TrackableICVs) {
2265       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2266       Optional<Value *> NewReplVal =
2267           ICVTrackingAA.getUniqueReplacementValue(ICV);
2268 
2269       if (ReplVal == NewReplVal)
2270         continue;
2271 
2272       ReplVal = NewReplVal;
2273       Changed = ChangeStatus::CHANGED;
2274     }
2275     return Changed;
2276   }
2277 };
2278 
2279 struct AAExecutionDomainFunction : public AAExecutionDomain {
2280   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2281       : AAExecutionDomain(IRP, A) {}
2282 
2283   const std::string getAsStr() const override {
2284     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2285            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2286   }
2287 
2288   /// See AbstractAttribute::trackStatistics().
2289   void trackStatistics() const override {}
2290 
2291   void initialize(Attributor &A) override {
2292     Function *F = getAnchorScope();
2293     for (const auto &BB : *F)
2294       SingleThreadedBBs.insert(&BB);
2295     NumBBs = SingleThreadedBBs.size();
2296   }
2297 
2298   ChangeStatus manifest(Attributor &A) override {
2299     LLVM_DEBUG({
2300       for (const BasicBlock *BB : SingleThreadedBBs)
2301         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2302                << BB->getName() << " is executed by a single thread.\n";
2303     });
2304     return ChangeStatus::UNCHANGED;
2305   }
2306 
2307   ChangeStatus updateImpl(Attributor &A) override;
2308 
2309   /// Check if an instruction is executed by a single thread.
2310   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2311     return isExecutedByInitialThreadOnly(*I.getParent());
2312   }
2313 
2314   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2315     return SingleThreadedBBs.contains(&BB);
2316   }
2317 
2318   /// Set of basic blocks that are executed by a single thread.
2319   DenseSet<const BasicBlock *> SingleThreadedBBs;
2320 
2321   /// Total number of basic blocks in this function.
  unsigned long NumBBs;
2323 };
2324 
2325 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2326   Function *F = getAnchorScope();
2327   ReversePostOrderTraversal<Function *> RPOT(F);
2328   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2329 
2330   bool AllCallSitesKnown;
2331   auto PredForCallSite = [&](AbstractCallSite ACS) {
2332     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2333         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2334         DepClassTy::REQUIRED);
2335     return ExecutionDomainAA.isExecutedByInitialThreadOnly(
2336         *ACS.getInstruction());
2337   };
2338 
2339   if (!A.checkForAllCallSites(PredForCallSite, *this,
2340                               /* RequiresAllCallSites */ true,
2341                               AllCallSitesKnown))
2342     SingleThreadedBBs.erase(&F->getEntryBlock());
2343 
  // Check if the edge into the successor block compares a thread-id function
  // to a constant zero.
  // TODO: Use AAValueSimplify to simplify and propagate constants.
2347   // TODO: Check more than a single use for thread ID's.
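  // A typical accepted edge looks like this (illustrative NVPTX IR):
  //   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  //   %cmp = icmp eq i32 %tid, 0
  //   br i1 %cmp, label %initial.only, label %other
  // in which case only %initial.only is considered executed by the initial
  // thread.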
2348   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2349     if (!Edge || !Edge->isConditional())
2350       return false;
2351     if (Edge->getSuccessor(0) != SuccessorBB)
2352       return false;
2353 
2354     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2355     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2356       return false;
2357 
2358     // Temporarily match the pattern generated by clang for teams regions.
2359     // TODO: Remove this once the new runtime is in place.
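    // The matched expression is, roughly,
    //   tid.x == (ntid.x - 1) & ~(warpsize - 1)
    // i.e. the first lane of the last warp, which is the "master" thread of
    // the generic-mode state machine.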
2360     ConstantInt *One, *NegOne;
2361     CmpInst::Predicate Pred;
2362     auto &&m_ThreadID = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_tid_x>();
2363     auto &&m_WarpSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_warpsize>();
2364     auto &&m_BlockSize = m_Intrinsic<Intrinsic::nvvm_read_ptx_sreg_ntid_x>();
2365     if (match(Cmp, m_Cmp(Pred, m_ThreadID,
2366                          m_And(m_Sub(m_BlockSize, m_ConstantInt(One)),
2367                                m_Xor(m_Sub(m_WarpSize, m_ConstantInt(One)),
2368                                      m_ConstantInt(NegOne))))))
2369       if (One->isOne() && NegOne->isMinusOne() &&
2370           Pred == CmpInst::Predicate::ICMP_EQ)
2371         return true;
2372 
2373     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2374     if (!C || !C->isZero())
2375       return false;
2376 
2377     if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2378       if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2379         return true;
2380     if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2381       if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2382         return true;
2383 
2384     return false;
2385   };
2386 
2387   // Merge all the predecessor states into the current basic block. A basic
2388   // block is executed by a single thread if all of its predecessors are.
2389   auto MergePredecessorStates = [&](BasicBlock *BB) {
2390     if (pred_begin(BB) == pred_end(BB))
2391       return SingleThreadedBBs.contains(BB);
2392 
2393     bool IsInitialThread = true;
2394     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2395          PredBB != PredEndBB; ++PredBB) {
2396       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2397                               BB))
2398         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2399     }
2400 
2401     return IsInitialThread;
2402   };
2403 
2404   for (auto *BB : RPOT) {
2405     if (!MergePredecessorStates(BB))
2406       SingleThreadedBBs.erase(BB);
2407   }
2408 
2409   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2410              ? ChangeStatus::UNCHANGED
2411              : ChangeStatus::CHANGED;
2412 }
2413 
2414 /// Try to replace memory allocation calls called by a single thread with a
2415 /// static buffer of shared memory.
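///
/// A minimal sketch of the intended rewrite (illustrative IR only):
///   %p = call i8* @__kmpc_alloc_shared(i64 24)
///   ...
///   call void @__kmpc_free_shared(i8* %p)
/// becomes a module-level buffer in the shared address space, e.g.
///   @x_shared = internal addrspace(3) global [24 x i8] undef
/// with uses of %p rewritten to a pointer cast of the buffer and both runtime
/// calls removed.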
2416 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2417   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2418   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2419 
2420   /// Create an abstract attribute view for the position \p IRP.
2421   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2422                                            Attributor &A);
2423 
2424   /// See AbstractAttribute::getName().
2425   const std::string getName() const override { return "AAHeapToShared"; }
2426 
2427   /// See AbstractAttribute::getIdAddr().
2428   const char *getIdAddr() const override { return &ID; }
2429 
2430   /// This function should return true if the type of the \p AA is
2431   /// AAHeapToShared.
2432   static bool classof(const AbstractAttribute *AA) {
2433     return (AA->getIdAddr() == &ID);
2434   }
2435 
2436   /// Unique ID (due to the unique address)
2437   static const char ID;
2438 };
2439 
2440 struct AAHeapToSharedFunction : public AAHeapToShared {
2441   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2442       : AAHeapToShared(IRP, A) {}
2443 
2444   const std::string getAsStr() const override {
2445     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2446            " malloc calls eligible.";
2447   }
2448 
2449   /// See AbstractAttribute::trackStatistics().
2450   void trackStatistics() const override {}
2451 
2452   void initialize(Attributor &A) override {
2453     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2454     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2455 
2456     for (User *U : RFI.Declaration->users())
2457       if (CallBase *CB = dyn_cast<CallBase>(U))
2458         MallocCalls.insert(CB);
2459   }
2460 
2461   ChangeStatus manifest(Attributor &A) override {
2462     if (MallocCalls.empty())
2463       return ChangeStatus::UNCHANGED;
2464 
2465     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2466     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2467 
2468     Function *F = getAnchorScope();
2469     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2470                                             DepClassTy::OPTIONAL);
2471 
2472     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2473     for (CallBase *CB : MallocCalls) {
2474       // Skip replacing this if HeapToStack has already claimed it.
2475       if (HS && HS->isKnownHeapToStack(*CB))
2476         continue;
2477 
2478       // Find the unique free call to remove it.
2479       SmallVector<CallBase *, 4> FreeCalls;
2480       for (auto *U : CB->users()) {
2481         CallBase *C = dyn_cast<CallBase>(U);
2482         if (C && C->getCalledFunction() == FreeCall.Declaration)
2483           FreeCalls.push_back(C);
2484       }
2485       if (FreeCalls.size() != 1)
2486         continue;
2487 
2488       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2489 
2490       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in "
2491                         << CB->getCaller()->getName() << " with "
2492                         << AllocSize->getZExtValue()
2493                         << " bytes of shared memory\n");
2494 
2495       // Create a new shared memory buffer of the same size as the allocation
2496       // and replace all the uses of the original allocation with it.
2497       Module *M = CB->getModule();
2498       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2499       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2500       auto *SharedMem = new GlobalVariable(
2501           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2502           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2503           GlobalValue::NotThreadLocal,
2504           static_cast<unsigned>(AddressSpace::Shared));
2505       auto *NewBuffer =
2506           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2507 
2508       auto Remark = [&](OptimizationRemark OR) {
2509         return OR << "Replaced globalized variable with "
2510                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2511                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2512                   << "of shared memory";
2513       };
2514       A.emitRemark<OptimizationRemark>(CB, "OpenMPReplaceGlobalization",
2515                                        Remark);
2516 
2517       SharedMem->setAlignment(MaybeAlign(32));
2518 
2519       A.changeValueAfterManifest(*CB, *NewBuffer);
2520       A.deleteAfterManifest(*CB);
2521       A.deleteAfterManifest(*FreeCalls.front());
2522 
2523       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2524       Changed = ChangeStatus::CHANGED;
2525     }
2526 
2527     return Changed;
2528   }
2529 
2530   ChangeStatus updateImpl(Attributor &A) override {
2531     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2532     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2533     Function *F = getAnchorScope();
2534 
2535     auto NumMallocCalls = MallocCalls.size();
2536 
    // Only consider malloc calls that are executed by a single thread and
    // have a constant allocation size.
2538     for (User *U : RFI.Declaration->users()) {
2539       const auto &ED = A.getAAFor<AAExecutionDomain>(
2540           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2541       if (CallBase *CB = dyn_cast<CallBase>(U))
        if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
2543             !ED.isExecutedByInitialThreadOnly(*CB))
2544           MallocCalls.erase(CB);
2545     }
2546 
2547     if (NumMallocCalls != MallocCalls.size())
2548       return ChangeStatus::CHANGED;
2549 
2550     return ChangeStatus::UNCHANGED;
2551   }
2552 
2553   /// Collection of all malloc calls in a function.
2554   SmallPtrSet<CallBase *, 4> MallocCalls;
2555 };
2556 
2557 } // namespace
2558 
2559 const char AAICVTracker::ID = 0;
2560 const char AAExecutionDomain::ID = 0;
2561 const char AAHeapToShared::ID = 0;
2562 
2563 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
2564                                               Attributor &A) {
2565   AAICVTracker *AA = nullptr;
2566   switch (IRP.getPositionKind()) {
2567   case IRPosition::IRP_INVALID:
2568   case IRPosition::IRP_FLOAT:
2569   case IRPosition::IRP_ARGUMENT:
2570   case IRPosition::IRP_CALL_SITE_ARGUMENT:
2571     llvm_unreachable("ICVTracker can only be created for function position!");
2572   case IRPosition::IRP_RETURNED:
2573     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
2574     break;
2575   case IRPosition::IRP_CALL_SITE_RETURNED:
2576     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
2577     break;
2578   case IRPosition::IRP_CALL_SITE:
2579     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
2580     break;
2581   case IRPosition::IRP_FUNCTION:
2582     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
2583     break;
2584   }
2585 
2586   return *AA;
2587 }
2588 
2589 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
2590                                                         Attributor &A) {
2591   AAExecutionDomainFunction *AA = nullptr;
2592   switch (IRP.getPositionKind()) {
2593   case IRPosition::IRP_INVALID:
2594   case IRPosition::IRP_FLOAT:
2595   case IRPosition::IRP_ARGUMENT:
2596   case IRPosition::IRP_CALL_SITE_ARGUMENT:
2597   case IRPosition::IRP_RETURNED:
2598   case IRPosition::IRP_CALL_SITE_RETURNED:
2599   case IRPosition::IRP_CALL_SITE:
2600     llvm_unreachable(
2601         "AAExecutionDomain can only be created for function position!");
2602   case IRPosition::IRP_FUNCTION:
2603     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
2604     break;
2605   }
2606 
2607   return *AA;
2608 }
2609 
2610 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
2611                                                   Attributor &A) {
2612   AAHeapToSharedFunction *AA = nullptr;
2613   switch (IRP.getPositionKind()) {
2614   case IRPosition::IRP_INVALID:
2615   case IRPosition::IRP_FLOAT:
2616   case IRPosition::IRP_ARGUMENT:
2617   case IRPosition::IRP_CALL_SITE_ARGUMENT:
2618   case IRPosition::IRP_RETURNED:
2619   case IRPosition::IRP_CALL_SITE_RETURNED:
2620   case IRPosition::IRP_CALL_SITE:
2621     llvm_unreachable(
2622         "AAHeapToShared can only be created for function position!");
2623   case IRPosition::IRP_FUNCTION:
2624     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
2625     break;
2626   }
2627 
2628   return *AA;
2629 }
2630 
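// The module pass is typically exercised via the new pass manager, e.g.
// (assuming the pass names registered in PassRegistry.def):
//   opt -passes=openmp-opt input.ll -S
//   opt -passes=openmp-opt-cgscc input.ll -S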
2631 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
2632   if (!containsOpenMP(M, OMPInModule))
2633     return PreservedAnalyses::all();
2634 
2635   if (DisableOpenMPOptimizations)
2636     return PreservedAnalyses::all();
2637 
2638   // Create internal copies of each function if this is a kernel Module.
2639   DenseSet<const Function *> InternalizedFuncs;
2640   if (!OMPInModule.getKernels().empty())
2641     for (Function &F : M)
2642       if (!F.isDeclaration() && !OMPInModule.getKernels().contains(&F))
2643         if (Attributor::internalizeFunction(F, /* Force */ true))
2644           InternalizedFuncs.insert(&F);
2645 
2646   // Look at every function definition in the Module that wasn't internalized.
2647   SmallVector<Function *, 16> SCC;
2648   for (Function &F : M)
2649     if (!F.isDeclaration() && !InternalizedFuncs.contains(&F))
2650       SCC.push_back(&F);
2651 
2652   if (SCC.empty())
2653     return PreservedAnalyses::all();
2654 
2655   FunctionAnalysisManager &FAM =
2656       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
2657 
2658   AnalysisGetter AG(FAM);
2659 
2660   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
2661     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
2662   };
2663 
2664   BumpPtrAllocator Allocator;
2665   CallGraphUpdater CGUpdater;
2666 
2667   SetVector<Function *> Functions(SCC.begin(), SCC.end());
2668   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions,
2669                                 OMPInModule.getKernels());
2670 
  unsigned MaxFixpointIterations =
      (!OMPInModule.getKernels().empty()) ? 64 : 32;
  Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
               MaxFixpointIterations, OREGetter, DEBUG_TYPE);
2674 
2675   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
2676   bool Changed = OMPOpt.run(true);
2677   if (Changed)
2678     return PreservedAnalyses::none();
2679 
2680   return PreservedAnalyses::all();
2681 }
2682 
2683 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
2684                                           CGSCCAnalysisManager &AM,
2685                                           LazyCallGraph &CG,
2686                                           CGSCCUpdateResult &UR) {
2687   if (!containsOpenMP(*C.begin()->getFunction().getParent(), OMPInModule))
2688     return PreservedAnalyses::all();
2689 
2690   if (DisableOpenMPOptimizations)
2691     return PreservedAnalyses::all();
2692 
2693   SmallVector<Function *, 16> SCC;
2694   // If there are kernels in the module, we have to run on all SCC's.
2695   bool SCCIsInteresting = !OMPInModule.getKernels().empty();
2696   for (LazyCallGraph::Node &N : C) {
2697     Function *Fn = &N.getFunction();
2698     SCC.push_back(Fn);
2699 
2700     // Do we already know that the SCC contains kernels,
2701     // or that OpenMP functions are called from this SCC?
2702     if (SCCIsInteresting)
2703       continue;
2704     // If not, let's check that.
2705     SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
2706   }
2707 
2708   if (!SCCIsInteresting || SCC.empty())
2709     return PreservedAnalyses::all();
2710 
2711   FunctionAnalysisManager &FAM =
2712       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
2713 
2714   AnalysisGetter AG(FAM);
2715 
2716   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
2717     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
2718   };
2719 
2720   BumpPtrAllocator Allocator;
2721   CallGraphUpdater CGUpdater;
2722   CGUpdater.initialize(CG, C, AM, UR);
2723 
2724   SetVector<Function *> Functions(SCC.begin(), SCC.end());
2725   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
2726                                 /*CGSCC*/ Functions, OMPInModule.getKernels());
2727 
  unsigned MaxFixpointIterations =
      (!OMPInModule.getKernels().empty()) ? 64 : 32;
  Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
               MaxFixpointIterations, OREGetter, DEBUG_TYPE);

  OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
  bool Changed = OMPOpt.run(false);
  if (Changed)
    return PreservedAnalyses::none();

  return PreservedAnalyses::all();
}

namespace {

struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
  CallGraphUpdater CGUpdater;
  OpenMPInModule OMPInModule;
  static char ID;

  OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
    initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    CallGraphSCCPass::getAnalysisUsage(AU);
  }

  bool doInitialization(CallGraph &CG) override {
    // Disable the pass if there is no OpenMP (runtime call) in the module.
    containsOpenMP(CG.getModule(), OMPInModule);
    return false;
  }

  bool runOnSCC(CallGraphSCC &CGSCC) override {
    if (!containsOpenMP(CGSCC.getCallGraph().getModule(), OMPInModule))
      return false;
    if (DisableOpenMPOptimizations || skipSCC(CGSCC))
      return false;

    SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCCs.
    bool SCCIsInteresting = !OMPInModule.getKernels().empty();
    for (CallGraphNode *CGN : CGSCC) {
      Function *Fn = CGN->getFunction();
      if (!Fn || Fn->isDeclaration())
        continue;
      SCC.push_back(Fn);

      // Do we already know that the SCC contains kernels,
      // or that OpenMP functions are called from this SCC?
      if (SCCIsInteresting)
        continue;
      // If not, let's check that.
      SCCIsInteresting |= OMPInModule.containsOMPRuntimeCalls(Fn);
    }

    if (!SCCIsInteresting || SCC.empty())
      return false;

    CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
    CGUpdater.initialize(CG, CGSCC);

    // Maintain a map so the ORE for each function is built at most once.
    DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
    auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
      std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
      if (!ORE)
        ORE = std::make_unique<OptimizationRemarkEmitter>(F);
      return *ORE;
    };

    AnalysisGetter AG;
    SetVector<Function *> Functions(SCC.begin(), SCC.end());
    BumpPtrAllocator Allocator;
    OMPInformationCache InfoCache(
        *(Functions.back()->getParent()), AG, Allocator,
        /*CGSCC*/ Functions, OMPInModule.getKernels());

    unsigned MaxFixpointIterations =
        (!OMPInModule.getKernels().empty()) ? 64 : 32;
    Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
                 MaxFixpointIterations, OREGetter, DEBUG_TYPE);

    OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
    return OMPOpt.run(false);
  }

  bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
};

} // end anonymous namespace

void OpenMPInModule::identifyKernels(Module &M) {
  NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
  if (!MD)
    return;

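  // For illustration only, a kernel entry point typically shows up in the
  // named metadata in a form like the following sketch (function name is
  // hypothetical, exact encoding may vary):
  //
  //   !nvvm.annotations = !{!0}
  //   !0 = !{void ()* @kernel_fn, !"kernel", i32 1}
  //
  // Operand 0 references the kernel function and operand 1 is the "kernel"
  // string checked below.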
  for (auto *Op : MD->operands()) {
    if (Op->getNumOperands() < 2)
      continue;
    MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
    if (!KindID || KindID->getString() != "kernel")
      continue;

    Function *KernelFn =
        mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
    if (!KernelFn)
      continue;

    ++NumOpenMPTargetRegionKernels;

    Kernels.insert(KernelFn);
  }
}

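// Determine whether \p M contains OpenMP runtime calls. The answer is cached
// in \p OMPInModule, so repeated queries return the stored result.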
bool llvm::omp::containsOpenMP(Module &M, OpenMPInModule &OMPInModule) {
  if (OMPInModule.isKnown())
    return OMPInModule;

  auto RecordFunctionsContainingUsesOf = [&](Function *F) {
    for (User *U : F->users())
      if (auto *I = dyn_cast<Instruction>(U))
        OMPInModule.FuncsWithOMPRuntimeCalls.insert(I->getFunction());
  };

  // MSVC errors out on overly long if-else chains; work around this by
  // emitting the runtime-function lookups as separate if statements inside a
  // do-while block.
  do {
#define OMP_RTL(_Enum, _Name, ...)                                            \
  if (Function *F = M.getFunction(_Name)) {                                    \
    RecordFunctionsContainingUsesOf(F);                                        \
    OMPInModule = true;                                                        \
  }
#include "llvm/Frontend/OpenMP/OMPKinds.def"
  } while (false);
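  // For illustration, with OMP_RTL defined as above the .def include roughly
  // expands to one guarded lookup per known runtime function, e.g.:
  //
  //   if (Function *F = M.getFunction("__kmpc_fork_call")) {
  //     RecordFunctionsContainingUsesOf(F);
  //     OMPInModule = true;
  //   }
  //
  // so any runtime entry point present in the module marks it as containing
  // OpenMP and records the functions that call it.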

  // Identify kernels once. TODO: We should split the OMPInformationCache into
  // a module and an SCC part. The kernel information, among other things,
  // could go into the module part.
  if (OMPInModule.isKnown() && OMPInModule) {
    OMPInModule.identifyKernels(M);
    return true;
  }

  return OMPInModule = false;
}

char OpenMPOptCGSCCLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
                      "OpenMP specific optimizations", false, false)
INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
                    "OpenMP specific optimizations", false, false)

Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
  return new OpenMPOptCGSCCLegacyPass();
}