1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/InitializePasses.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Transforms/IPO.h"
39 #include "llvm/Transforms/IPO/Attributor.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
42 #include "llvm/Transforms/Utils/CodeExtractor.h"
43 
44 using namespace llvm;
45 using namespace omp;
46 
47 #define DEBUG_TYPE "openmp-opt"
48 
49 static cl::opt<bool> DisableOpenMPOptimizations(
50     "openmp-opt-disable", cl::ZeroOrMore,
51     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
52     cl::init(false));
53 
54 static cl::opt<bool> EnableParallelRegionMerging(
55     "openmp-opt-enable-merging", cl::ZeroOrMore,
56     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
57     cl::init(false));
58 
59 static cl::opt<bool>
60     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
61                            cl::desc("Disable function internalization."),
62                            cl::Hidden, cl::init(false));
63 
64 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
65                                     cl::Hidden);
66 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
67                                         cl::init(false), cl::Hidden);
68 
69 static cl::opt<bool> HideMemoryTransferLatency(
70     "openmp-hide-memory-transfer-latency",
71     cl::desc("[WIP] Tries to hide the latency of host to device memory"
72              " transfers"),
73     cl::Hidden, cl::init(false));
74 
75 static cl::opt<bool> DisableOpenMPOptDeglobalization(
76     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
77     cl::desc("Disable OpenMP optimizations involving deglobalization."),
78     cl::Hidden, cl::init(false));
79 
80 static cl::opt<bool> DisableOpenMPOptSPMDization(
81     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
82     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
83     cl::Hidden, cl::init(false));
84 
85 static cl::opt<bool> DisableOpenMPOptFolding(
86     "openmp-opt-disable-folding", cl::ZeroOrMore,
87     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
88     cl::init(false));
89 
90 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
91     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
92     cl::desc("Disable OpenMP optimizations that replace the state machine."),
93     cl::Hidden, cl::init(false));
94 
95 static cl::opt<bool> PrintModuleAfterOptimizations(
96     "openmp-opt-print-module", cl::ZeroOrMore,
97     cl::desc("Print the current module after OpenMP optimizations."),
98     cl::Hidden, cl::init(false));
99 
100 static cl::opt<bool> AlwaysInlineDeviceFunctions(
101     "openmp-opt-inline-device", cl::ZeroOrMore,
102     cl::desc("Inline all applicable functions on the device."), cl::Hidden,
103     cl::init(false));
104 
105 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
106           "Number of OpenMP runtime calls deduplicated");
107 STATISTIC(NumOpenMPParallelRegionsDeleted,
108           "Number of OpenMP parallel regions deleted");
109 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
110           "Number of OpenMP runtime functions identified");
111 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
112           "Number of OpenMP runtime function uses identified");
113 STATISTIC(NumOpenMPTargetRegionKernels,
114           "Number of OpenMP target region entry points (=kernels) identified");
115 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
116           "Number of OpenMP target region entry points (=kernels) executed in "
117           "SPMD-mode instead of generic-mode");
118 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
119           "Number of OpenMP target region entry points (=kernels) executed in "
120           "generic-mode without a state machine");
121 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
122           "Number of OpenMP target region entry points (=kernels) executed in "
123           "generic-mode with customized state machines with fallback");
124 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
125           "Number of OpenMP target region entry points (=kernels) executed in "
126           "generic-mode with customized state machines without fallback");
127 STATISTIC(
128     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
129     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
130 STATISTIC(NumOpenMPParallelRegionsMerged,
131           "Number of OpenMP parallel regions merged");
132 STATISTIC(NumBytesMovedToSharedMemory,
133           "Number of bytes moved to shared memory");
134 
135 #if !defined(NDEBUG)
136 static constexpr auto TAG = "[" DEBUG_TYPE "]";
137 #endif
138 
139 namespace {
140 
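/// GPU address spaces referenced below. The numbering follows the convention
/// used by the NVPTX and AMDGPU backends (0 = generic, 1 = global, 3 = shared,
/// 4 = constant, 5 = private/local).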
141 enum class AddressSpace : unsigned {
142   Generic = 0,
143   Global = 1,
144   Shared = 3,
145   Constant = 4,
146   Local = 5,
147 };
148 
149 struct AAHeapToShared;
150 
151 struct AAICVTracker;
152 
153 /// OpenMP specific information. For now, it stores the RFIs and ICVs that are
154 /// also needed for Attributor runs.
155 struct OMPInformationCache : public InformationCache {
156   OMPInformationCache(Module &M, AnalysisGetter &AG,
157                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
158                       SmallPtrSetImpl<Kernel> &Kernels)
159       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
160         Kernels(Kernels) {
161 
162     OMPBuilder.initialize();
163     initializeRuntimeFunctions();
164     initializeInternalControlVars();
165   }
166 
167   /// Generic information that describes an internal control variable.
168   struct InternalControlVarInfo {
169     /// The kind, as described by InternalControlVar enum.
170     InternalControlVar Kind;
171 
172     /// The name of the ICV.
173     StringRef Name;
174 
175     /// Environment variable associated with this ICV.
176     StringRef EnvVarName;
177 
178     /// Initial value kind.
179     ICVInitValue InitKind;
180 
181     /// Initial value.
182     ConstantInt *InitValue;
183 
184     /// Setter RTL function associated with this ICV.
185     RuntimeFunction Setter;
186 
187     /// Getter RTL function associated with this ICV.
188     RuntimeFunction Getter;
189 
190     /// RTL Function corresponding to the override clause of this ICV
191     RuntimeFunction Clause;
192   };
193 
194   /// Generic information that describes a runtime function
195   struct RuntimeFunctionInfo {
196 
197     /// The kind, as described by the RuntimeFunction enum.
198     RuntimeFunction Kind;
199 
200     /// The name of the function.
201     StringRef Name;
202 
203     /// Flag to indicate a variadic function.
204     bool IsVarArg;
205 
206     /// The return type of the function.
207     Type *ReturnType;
208 
209     /// The argument types of the function.
210     SmallVector<Type *, 8> ArgumentTypes;
211 
212     /// The declaration if available.
213     Function *Declaration = nullptr;
214 
215     /// Uses of this runtime function per function containing the use.
216     using UseVector = SmallVector<Use *, 16>;
217 
218     /// Clear UsesMap for runtime function.
219     void clearUsesMap() { UsesMap.clear(); }
220 
221     /// Boolean conversion that is true if the runtime function was found.
222     operator bool() const { return Declaration; }
223 
224     /// Return the vector of uses in function \p F.
225     UseVector &getOrCreateUseVector(Function *F) {
226       std::shared_ptr<UseVector> &UV = UsesMap[F];
227       if (!UV)
228         UV = std::make_shared<UseVector>();
229       return *UV;
230     }
231 
232     /// Return the vector of uses in function \p F or `nullptr` if there are
233     /// none.
234     const UseVector *getUseVector(Function &F) const {
235       auto I = UsesMap.find(&F);
236       if (I != UsesMap.end())
237         return I->second.get();
238       return nullptr;
239     }
240 
241     /// Return how many functions contain uses of this runtime function.
242     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
243 
244     /// Return the number of arguments (or the minimal number for variadic
245     /// functions).
246     size_t getNumArgs() const { return ArgumentTypes.size(); }
247 
248     /// Run the callback \p CB on each use and forget the use if the result is
249     /// true. The function in which the use was encountered is passed to the
250     /// callback as its second argument.
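    ///
    /// A minimal usage sketch (the predicate is hypothetical; the helper
    /// getCallIfRegularCall is defined on OpenMPOpt below):
    ///   RFI.foreachUse(SCCFunctions, [&](Use &U, Function &F) {
    ///     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
    ///     if (!CI || !isProvenRemovable(*CI)) // Hypothetical predicate.
    ///       return false;                     // Keep the use.
    ///     CI->eraseFromParent();
    ///     return true;                        // Forget the use.
    ///   });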
251     void foreachUse(SmallVectorImpl<Function *> &SCC,
252                     function_ref<bool(Use &, Function &)> CB) {
253       for (Function *F : SCC)
254         foreachUse(CB, F);
255     }
256 
257     /// Run the callback \p CB on each use within the function \p F and forget
258     /// the use if the result is true.
259     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
260       SmallVector<unsigned, 8> ToBeDeleted;
261       ToBeDeleted.clear();
262 
263       unsigned Idx = 0;
264       UseVector &UV = getOrCreateUseVector(F);
265 
266       for (Use *U : UV) {
267         if (CB(*U, *F))
268           ToBeDeleted.push_back(Idx);
269         ++Idx;
270       }
271 
272       // Remove the to-be-deleted indices in reverse order as prior
273       // modifications will not modify the smaller indices.
274       while (!ToBeDeleted.empty()) {
275         unsigned Idx = ToBeDeleted.pop_back_val();
276         UV[Idx] = UV.back();
277         UV.pop_back();
278       }
279     }
280 
281   private:
282     /// Map from functions to all uses of this runtime function contained in
283     /// them.
284     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
285 
286   public:
287     /// Iterators for the uses of this runtime function.
288     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
289     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
290   };
291 
292   /// An OpenMP-IR-Builder instance
293   OpenMPIRBuilder OMPBuilder;
294 
295   /// Map from runtime function kind to the runtime function description.
296   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
297                   RuntimeFunction::OMPRTL___last>
298       RFIs;
299 
300   /// Map from function declarations/definitions to their runtime enum type.
301   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
302 
303   /// Map from ICV kind to the ICV description.
304   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
305                   InternalControlVar::ICV___last>
306       ICVs;
307 
308   /// Helper to initialize all internal control variable information for those
309   /// defined in OMPKinds.def.
310   void initializeInternalControlVars() {
311 #define ICV_RT_SET(_Name, RTL)                                                 \
312   {                                                                            \
313     auto &ICV = ICVs[_Name];                                                   \
314     ICV.Setter = RTL;                                                          \
315   }
316 #define ICV_RT_GET(Name, RTL)                                                  \
317   {                                                                            \
318     auto &ICV = ICVs[Name];                                                    \
319     ICV.Getter = RTL;                                                          \
320   }
321 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
322   {                                                                            \
323     auto &ICV = ICVs[Enum];                                                    \
324     ICV.Name = _Name;                                                          \
325     ICV.Kind = Enum;                                                           \
326     ICV.InitKind = Init;                                                       \
327     ICV.EnvVarName = _EnvVarName;                                              \
328     switch (ICV.InitKind) {                                                    \
329     case ICV_IMPLEMENTATION_DEFINED:                                           \
330       ICV.InitValue = nullptr;                                                 \
331       break;                                                                   \
332     case ICV_ZERO:                                                             \
333       ICV.InitValue = ConstantInt::get(                                        \
334           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
335       break;                                                                   \
336     case ICV_FALSE:                                                            \
337       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
338       break;                                                                   \
339     case ICV_LAST:                                                             \
340       break;                                                                   \
341     }                                                                          \
342   }
343 #include "llvm/Frontend/OpenMP/OMPKinds.def"
344   }
345 
346   /// Returns true if the function declaration \p F matches the runtime
347   /// function types, that is, return type \p RTFRetType, and argument types
348   /// \p RTFArgTypes.
349   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
350                                   SmallVector<Type *, 8> &RTFArgTypes) {
351     // TODO: We should output information to the user (under debug output
352     //       and via remarks).
353 
354     if (!F)
355       return false;
356     if (F->getReturnType() != RTFRetType)
357       return false;
358     if (F->arg_size() != RTFArgTypes.size())
359       return false;
360 
361     auto RTFTyIt = RTFArgTypes.begin();
362     for (Argument &Arg : F->args()) {
363       if (Arg.getType() != *RTFTyIt)
364         return false;
365 
366       ++RTFTyIt;
367     }
368 
369     return true;
370   }
371 
372   // Helper to collect all uses of the declaration in the UsesMap.
373   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
374     unsigned NumUses = 0;
375     if (!RFI.Declaration)
376       return NumUses;
377     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
378 
379     if (CollectStats) {
380       NumOpenMPRuntimeFunctionsIdentified += 1;
381       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
382     }
383 
384     // TODO: We directly convert uses into proper calls and unknown uses.
385     for (Use &U : RFI.Declaration->uses()) {
386       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
387         if (ModuleSlice.count(UserI->getFunction())) {
388           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
389           ++NumUses;
390         }
391       } else {
392         RFI.getOrCreateUseVector(nullptr).push_back(&U);
393         ++NumUses;
394       }
395     }
396     return NumUses;
397   }
398 
399   // Helper function to recollect uses of a runtime function.
400   void recollectUsesForFunction(RuntimeFunction RTF) {
401     auto &RFI = RFIs[RTF];
402     RFI.clearUsesMap();
403     collectUses(RFI, /*CollectStats*/ false);
404   }
405 
406   // Helper function to recollect uses of all runtime functions.
407   void recollectUses() {
408     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
409       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
410   }
411 
412   /// Helper to initialize all runtime function information for those defined
413   /// in OMPKinds.def.
414   void initializeRuntimeFunctions() {
415     Module &M = *((*ModuleSlice.begin())->getParent());
416 
417     // Helper macros for handling __VA_ARGS__ in OMP_RTL
418 #define OMP_TYPE(VarName, ...)                                                 \
419   Type *VarName = OMPBuilder.VarName;                                          \
420   (void)VarName;
421 
422 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
423   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
424   (void)VarName##Ty;                                                           \
425   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
426   (void)VarName##PtrTy;
427 
428 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
429   FunctionType *VarName = OMPBuilder.VarName;                                  \
430   (void)VarName;                                                               \
431   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
432   (void)VarName##Ptr;
433 
434 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
435   StructType *VarName = OMPBuilder.VarName;                                    \
436   (void)VarName;                                                               \
437   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
438   (void)VarName##Ptr;
439 
440 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
441   {                                                                            \
442     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
443     Function *F = M.getFunction(_Name);                                        \
444     RTLFunctions.insert(F);                                                    \
445     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
446       RuntimeFunctionIDMap[F] = _Enum;                                         \
447       F->removeFnAttr(Attribute::NoInline);                                    \
448       auto &RFI = RFIs[_Enum];                                                 \
449       RFI.Kind = _Enum;                                                        \
450       RFI.Name = _Name;                                                        \
451       RFI.IsVarArg = _IsVarArg;                                                \
452       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
453       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
454       RFI.Declaration = F;                                                     \
455       unsigned NumUses = collectUses(RFI);                                     \
456       (void)NumUses;                                                           \
457       LLVM_DEBUG({                                                             \
458         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
459                << " found\n";                                                  \
460         if (RFI.Declaration)                                                   \
461           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
462                  << RFI.getNumFunctionsWithUses()                              \
463                  << " different functions.\n";                                 \
464       });                                                                      \
465     }                                                                          \
466   }
467 #include "llvm/Frontend/OpenMP/OMPKinds.def"
468 
469     // TODO: We should attach the attributes defined in OMPKinds.def.
470   }
471 
472   /// Collection of known kernels (\see Kernel) in the module.
473   SmallPtrSetImpl<Kernel> &Kernels;
474 
475   /// Collection of known OpenMP runtime functions.
476   DenseSet<const Function *> RTLFunctions;
477 };
478 
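/// A BooleanState augmented with a SetVector of elements. If
/// \p InsertInvalidates is true, inserting any element drops the boolean state
/// to a pessimistic fixpoint, which is useful when the mere existence of an
/// element already invalidates the tracked property.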
479 template <typename Ty, bool InsertInvalidates = true>
480 struct BooleanStateWithSetVector : public BooleanState {
481   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
482   bool insert(const Ty &Elem) {
483     if (InsertInvalidates)
484       BooleanState::indicatePessimisticFixpoint();
485     return Set.insert(Elem);
486   }
487 
488   const Ty &operator[](int Idx) const { return Set[Idx]; }
489   bool operator==(const BooleanStateWithSetVector &RHS) const {
490     return BooleanState::operator==(RHS) && Set == RHS.Set;
491   }
492   bool operator!=(const BooleanStateWithSetVector &RHS) const {
493     return !(*this == RHS);
494   }
495 
496   bool empty() const { return Set.empty(); }
497   size_t size() const { return Set.size(); }
498 
499   /// "Clamp" this state with \p RHS.
500   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
501     BooleanState::operator^=(RHS);
502     Set.insert(RHS.Set.begin(), RHS.Set.end());
503     return *this;
504   }
505 
506 private:
507   /// A set to keep track of elements.
508   SetVector<Ty> Set;
509 
510 public:
511   typename decltype(Set)::iterator begin() { return Set.begin(); }
512   typename decltype(Set)::iterator end() { return Set.end(); }
513   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
514   typename decltype(Set)::const_iterator end() const { return Set.end(); }
515 };
516 
517 template <typename Ty, bool InsertInvalidates = true>
518 using BooleanStateWithPtrSetVector =
519     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
520 
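/// State tracked for a device kernel and the functions it can reach: the
/// parallel regions that may be executed, SPMD compatibility, the
/// __kmpc_target_init and __kmpc_target_deinit calls, and the kernel entries
/// that can reach the associated function.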
521 struct KernelInfoState : AbstractState {
522   /// Flag to track if we reached a fixpoint.
523   bool IsAtFixpoint = false;
524 
525   /// The parallel regions (identified by the outlined parallel functions) that
526   /// can be reached from the associated function.
527   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
528       ReachedKnownParallelRegions;
529 
530   /// State to track what parallel region we might reach.
531   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
532 
533   /// State to track whether we are in SPMD-mode, assumed or known, and why we
534   /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
535   /// also be false.
536   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
537 
538   /// The __kmpc_target_init call in this kernel, if any. If we find more than
539   /// one we abort as the kernel is malformed.
540   CallBase *KernelInitCB = nullptr;
541 
542   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
543   /// one we abort as the kernel is malformed.
544   CallBase *KernelDeinitCB = nullptr;
545 
546   /// Flag to indicate if the associated function is a kernel entry.
547   bool IsKernelEntry = false;
548 
549   /// State to track what kernel entries can reach the associated function.
550   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
551 
552   /// State to indicate if we can track the parallel level of the associated
553   /// function. We give up tracking if we encounter an unknown caller or if the
554   /// caller is __kmpc_parallel_51.
555   BooleanStateWithSetVector<uint8_t> ParallelLevels;
556 
557   /// Abstract State interface
558   ///{
559 
560   KernelInfoState() {}
561   KernelInfoState(bool BestState) {
562     if (!BestState)
563       indicatePessimisticFixpoint();
564   }
565 
566   /// See AbstractState::isValidState(...)
567   bool isValidState() const override { return true; }
568 
569   /// See AbstractState::isAtFixpoint(...)
570   bool isAtFixpoint() const override { return IsAtFixpoint; }
571 
572   /// See AbstractState::indicatePessimisticFixpoint(...)
573   ChangeStatus indicatePessimisticFixpoint() override {
574     IsAtFixpoint = true;
575     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
576     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
577     return ChangeStatus::CHANGED;
578   }
579 
580   /// See AbstractState::indicateOptimisticFixpoint(...)
581   ChangeStatus indicateOptimisticFixpoint() override {
582     IsAtFixpoint = true;
583     return ChangeStatus::UNCHANGED;
584   }
585 
586   /// Return the assumed state
587   KernelInfoState &getAssumed() { return *this; }
588   const KernelInfoState &getAssumed() const { return *this; }
589 
590   bool operator==(const KernelInfoState &RHS) const {
591     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
592       return false;
593     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
594       return false;
595     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
596       return false;
597     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
598       return false;
599     return true;
600   }
601 
602   /// Returns true if this kernel may contain an OpenMP parallel region.
603   bool mayContainParallelRegion() {
604     return !ReachedKnownParallelRegions.empty() ||
605            !ReachedUnknownParallelRegions.empty();
606   }
607 
608   /// Return the best possible (most optimistic) state.
609   static KernelInfoState getBestState() { return KernelInfoState(true); }
610 
611   static KernelInfoState getBestState(KernelInfoState &KIS) {
612     return getBestState();
613   }
614 
615   /// Return the worst possible (pessimistic) state.
616   static KernelInfoState getWorstState() { return KernelInfoState(false); }
617 
618   /// "Clamp" this state with \p KIS.
619   KernelInfoState operator^=(const KernelInfoState &KIS) {
620     // Do not merge two different _init and _deinit call sites.
621     if (KIS.KernelInitCB) {
622       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
623         indicatePessimisticFixpoint();
624       KernelInitCB = KIS.KernelInitCB;
625     }
626     if (KIS.KernelDeinitCB) {
627       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
628         indicatePessimisticFixpoint();
629       KernelDeinitCB = KIS.KernelDeinitCB;
630     }
631     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
632     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
633     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
634     return *this;
635   }
636 
637   KernelInfoState operator&=(const KernelInfoState &KIS) {
638     return (*this ^= KIS);
639   }
640 
641   ///}
642 };
643 
644 /// Used to map the values physically (in the IR) stored in an offload
645 /// array to a vector in memory.
646 struct OffloadArray {
647   /// Physical array (in the IR).
648   AllocaInst *Array = nullptr;
649   /// Mapped values.
650   SmallVector<Value *, 8> StoredValues;
651   /// Last stores made in the offload array.
652   SmallVector<StoreInst *, 8> LastAccesses;
653 
654   OffloadArray() = default;
655 
656   /// Initializes the OffloadArray with the values stored in \p Array before
657   /// instruction \p Before is reached. Returns false if the initialization
658   /// fails.
659   /// This MUST be used immediately after the construction of the object.
660   bool initialize(AllocaInst &Array, Instruction &Before) {
661     if (!Array.getAllocatedType()->isArrayTy())
662       return false;
663 
664     if (!getValues(Array, Before))
665       return false;
666 
667     this->Array = &Array;
668     return true;
669   }
670 
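  /// Argument positions of the device ID and the offload arrays in
  /// __tgt_target_data_begin_mapper-style runtime calls (see the schematic
  /// call in getValuesInOffloadArrays).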
671   static const unsigned DeviceIDArgNum = 1;
672   static const unsigned BasePtrsArgNum = 3;
673   static const unsigned PtrsArgNum = 4;
674   static const unsigned SizesArgNum = 5;
675 
676 private:
677   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
678   /// \p Array, leaving StoredValues with the values stored before the
679   /// instruction \p Before is reached.
680   bool getValues(AllocaInst &Array, Instruction &Before) {
681     // Initialize container.
682     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
683     StoredValues.assign(NumValues, nullptr);
684     LastAccesses.assign(NumValues, nullptr);
685 
686     // TODO: This assumes the instruction \p Before is in the same
687     //  BasicBlock as Array. Make it general, for any control flow graph.
688     BasicBlock *BB = Array.getParent();
689     if (BB != Before.getParent())
690       return false;
691 
692     const DataLayout &DL = Array.getModule()->getDataLayout();
693     const unsigned int PointerSize = DL.getPointerSize();
694 
695     for (Instruction &I : *BB) {
696       if (&I == &Before)
697         break;
698 
699       if (!isa<StoreInst>(&I))
700         continue;
701 
702       auto *S = cast<StoreInst>(&I);
703       int64_t Offset = -1;
704       auto *Dst =
705           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
706       if (Dst == &Array) {
707         int64_t Idx = Offset / PointerSize;
708         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
709         LastAccesses[Idx] = S;
710       }
711     }
712 
713     return isFilled();
714   }
715 
716   /// Returns true if all values in StoredValues and
717   /// LastAccesses are not nullptrs.
718   bool isFilled() {
719     const unsigned NumValues = StoredValues.size();
720     for (unsigned I = 0; I < NumValues; ++I) {
721       if (!StoredValues[I] || !LastAccesses[I])
722         return false;
723     }
724 
725     return true;
726   }
727 };
728 
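/// Driver for the OpenMP-specific optimizations listed in the file header. It
/// operates on the functions of an SCC (or of the whole module slice) using
/// the information gathered in the OMPInformationCache and, where needed, the
/// Attributor.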
729 struct OpenMPOpt {
730 
731   using OptimizationRemarkGetter =
732       function_ref<OptimizationRemarkEmitter &(Function *)>;
733 
734   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
735             OptimizationRemarkGetter OREGetter,
736             OMPInformationCache &OMPInfoCache, Attributor &A)
737       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
738         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
739 
740   /// Check if any remarks are enabled for openmp-opt
741   bool remarksEnabled() {
742     auto &Ctx = M.getContext();
743     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
744   }
745 
746   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
747   bool run(bool IsModulePass) {
748     if (SCC.empty())
749       return false;
750 
751     bool Changed = false;
752 
753     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
754                       << " functions in a slice with "
755                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
756 
757     if (IsModulePass) {
758       Changed |= runAttributor(IsModulePass);
759 
760       // Recollect uses, in case Attributor deleted any.
761       OMPInfoCache.recollectUses();
762 
763       // TODO: This should be folded into buildCustomStateMachine.
764       Changed |= rewriteDeviceCodeStateMachine();
765 
766       if (remarksEnabled())
767         analysisGlobalization();
768     } else {
769       if (PrintICVValues)
770         printICVs();
771       if (PrintOpenMPKernels)
772         printKernels();
773 
774       Changed |= runAttributor(IsModulePass);
775 
776       // Recollect uses, in case Attributor deleted any.
777       OMPInfoCache.recollectUses();
778 
779       Changed |= deleteParallelRegions();
780 
781       if (HideMemoryTransferLatency)
782         Changed |= hideMemTransfersLatency();
783       Changed |= deduplicateRuntimeCalls();
784       if (EnableParallelRegionMerging) {
785         if (mergeParallelRegions()) {
786           deduplicateRuntimeCalls();
787           Changed = true;
788         }
789       }
790     }
791 
792     return Changed;
793   }
794 
795   /// Print initial ICV values for testing.
796   /// FIXME: This should be done from the Attributor once it is added.
797   void printICVs() const {
798     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
799                                  ICV_proc_bind};
800 
801     for (Function *F : OMPInfoCache.ModuleSlice) {
802       for (auto ICV : ICVs) {
803         auto ICVInfo = OMPInfoCache.ICVs[ICV];
804         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
805           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
806                      << " Value: "
807                      << (ICVInfo.InitValue
808                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
809                              : "IMPLEMENTATION_DEFINED");
810         };
811 
812         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
813       }
814     }
815   }
816 
817   /// Print OpenMP GPU kernels for testing.
818   void printKernels() const {
819     for (Function *F : SCC) {
820       if (!OMPInfoCache.Kernels.count(F))
821         continue;
822 
823       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
824         return ORA << "OpenMP GPU kernel "
825                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
826       };
827 
828       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
829     }
830   }
831 
832   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
833   /// given, it has to be the callee or nullptr is returned.
834   static CallInst *getCallIfRegularCall(
835       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
836     CallInst *CI = dyn_cast<CallInst>(U.getUser());
837     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
838         (!RFI ||
839          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
840       return CI;
841     return nullptr;
842   }
843 
844   /// Return the call if \p V is a regular call. If \p RFI is given, it has to
845   /// be the callee or nullptr is returned.
846   static CallInst *getCallIfRegularCall(
847       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
848     CallInst *CI = dyn_cast<CallInst>(&V);
849     if (CI && !CI->hasOperandBundles() &&
850         (!RFI ||
851          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
852       return CI;
853     return nullptr;
854   }
855 
856 private:
857   /// Merge parallel regions when it is safe.
858   bool mergeParallelRegions() {
859     const unsigned CallbackCalleeOperand = 2;
860     const unsigned CallbackFirstArgOperand = 3;
861     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
862 
863     // Check if there are any __kmpc_fork_call calls to merge.
864     OMPInformationCache::RuntimeFunctionInfo &RFI =
865         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
866 
867     if (!RFI.Declaration)
868       return false;
869 
870     // Unmergable calls that prevent merging a parallel region.
871     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
872         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
873         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
874     };
875 
876     bool Changed = false;
877     LoopInfo *LI = nullptr;
878     DominatorTree *DT = nullptr;
879 
880     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
881 
882     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
883     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
884                          BasicBlock &ContinuationIP) {
885       BasicBlock *CGStartBB = CodeGenIP.getBlock();
886       BasicBlock *CGEndBB =
887           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
888       assert(StartBB != nullptr && "StartBB should not be null");
889       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
890       assert(EndBB != nullptr && "EndBB should not be null");
891       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
892     };
893 
894     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
895                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
896       ReplacementValue = &Inner;
897       return CodeGenIP;
898     };
899 
900     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
901 
902     /// Create a sequential execution region within a merged parallel region,
903     /// encapsulated in a master construct with a barrier for synchronization.
904     auto CreateSequentialRegion = [&](Function *OuterFn,
905                                       BasicBlock *OuterPredBB,
906                                       Instruction *SeqStartI,
907                                       Instruction *SeqEndI) {
908       // Isolate the instructions of the sequential region to a separate
909       // block.
910       BasicBlock *ParentBB = SeqStartI->getParent();
911       BasicBlock *SeqEndBB =
912           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
913       BasicBlock *SeqAfterBB =
914           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
915       BasicBlock *SeqStartBB =
916           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
917 
918       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
919              "Expected a different CFG");
920       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
921       ParentBB->getTerminator()->eraseFromParent();
922 
923       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
924                            BasicBlock &ContinuationIP) {
925         BasicBlock *CGStartBB = CodeGenIP.getBlock();
926         BasicBlock *CGEndBB =
927             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
928         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
929         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
930         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
931         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
932       };
933       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
934 
935       // Find outputs from the sequential region to outside users and
936       // broadcast their values to them.
937       for (Instruction &I : *SeqStartBB) {
938         SmallPtrSet<Instruction *, 4> OutsideUsers;
939         for (User *Usr : I.users()) {
940           Instruction &UsrI = *cast<Instruction>(Usr);
941           // Ignore uses in lifetime intrinsics; code extraction for the
942           // merged parallel region will fix them up.
943           if (UsrI.isLifetimeStartOrEnd())
944             continue;
945 
946           if (UsrI.getParent() != SeqStartBB)
947             OutsideUsers.insert(&UsrI);
948         }
949 
950         if (OutsideUsers.empty())
951           continue;
952 
953         // Emit an alloca in the outer region to store the broadcasted
954         // value.
955         const DataLayout &DL = M.getDataLayout();
956         AllocaInst *AllocaI = new AllocaInst(
957             I.getType(), DL.getAllocaAddrSpace(), nullptr,
958             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
959 
960         // Emit a store instruction in the sequential BB to update the
961         // value.
962         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
963 
964         // Emit a load instruction and replace the use of the output value
965         // with it.
966         for (Instruction *UsrI : OutsideUsers) {
967           LoadInst *LoadI = new LoadInst(
968               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
969           UsrI->replaceUsesOfWith(&I, LoadI);
970         }
971       }
972 
973       OpenMPIRBuilder::LocationDescription Loc(
974           InsertPointTy(ParentBB, ParentBB->end()), DL);
975       InsertPointTy SeqAfterIP =
976           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
977 
978       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
979 
980       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
981 
982       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
983                         << "\n");
984     };
985 
986     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
987     // contained in BB and only separated by instructions that can be
988     // redundantly executed in parallel. The block BB is split before the first
989     // call (in MergableCIs) and after the last so the entire region we merge
990     // into a single parallel region is contained in a single basic block
991     // without any other instructions. We use the OpenMPIRBuilder to outline
992     // that block and call the resulting function via __kmpc_fork_call.
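    //
    // Schematically (callee names are illustrative), two adjacent forks
    //   call void @__kmpc_fork_call(..., @outlined1, ...)
    //   call void @__kmpc_fork_call(..., @outlined2, ...)
    // become a single fork of a newly outlined function that calls @outlined1,
    // an explicit barrier, and then @outlined2.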
993     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
994       // TODO: Change the interface to allow single CIs to be expanded, e.g.,
995       // to include an outer loop.
996       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
997 
998       auto Remark = [&](OptimizationRemark OR) {
999         OR << "Parallel region merged with parallel region"
1000            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1001         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1002           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1003           if (CI != MergableCIs.back())
1004             OR << ", ";
1005         }
1006         return OR << ".";
1007       };
1008 
1009       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1010 
1011       Function *OriginalFn = BB->getParent();
1012       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1013                         << " parallel regions in " << OriginalFn->getName()
1014                         << "\n");
1015 
1016       // Isolate the calls to merge in a separate block.
1017       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1018       BasicBlock *AfterBB =
1019           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1020       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1021                            "omp.par.merged");
1022 
1023       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1024       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1025       BB->getTerminator()->eraseFromParent();
1026 
1027       // Create sequential regions for sequential instructions that are
1028       // in-between mergable parallel regions.
1029       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1030            It != End; ++It) {
1031         Instruction *ForkCI = *It;
1032         Instruction *NextForkCI = *(It + 1);
1033 
1034         // Continue if there are no in-between instructions.
1035         if (ForkCI->getNextNode() == NextForkCI)
1036           continue;
1037 
1038         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1039                                NextForkCI->getPrevNode());
1040       }
1041 
1042       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1043                                                DL);
1044       IRBuilder<>::InsertPoint AllocaIP(
1045           &OriginalFn->getEntryBlock(),
1046           OriginalFn->getEntryBlock().getFirstInsertionPt());
1047       // Create the merged parallel region with default proc binding, to
1048       // avoid overriding binding settings, and without explicit cancellation.
1049       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1050           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1051           OMP_PROC_BIND_default, /* IsCancellable */ false);
1052       BranchInst::Create(AfterBB, AfterIP.getBlock());
1053 
1054       // Perform the actual outlining.
1055       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1056                                        /* AllowExtractorSinking */ true);
1057 
1058       Function *OutlinedFn = MergableCIs.front()->getCaller();
1059 
1060       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1061       // callbacks.
1062       SmallVector<Value *, 8> Args;
1063       for (auto *CI : MergableCIs) {
1064         Value *Callee =
1065             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1066         FunctionType *FT =
1067             cast<FunctionType>(Callee->getType()->getPointerElementType());
1068         Args.clear();
1069         Args.push_back(OutlinedFn->getArg(0));
1070         Args.push_back(OutlinedFn->getArg(1));
1071         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1072              U < E; ++U)
1073           Args.push_back(CI->getArgOperand(U));
1074 
1075         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1076         if (CI->getDebugLoc())
1077           NewCI->setDebugLoc(CI->getDebugLoc());
1078 
1079         // Forward parameter attributes from the callback to the callee.
1080         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1081              U < E; ++U)
1082           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1083             NewCI->addParamAttr(
1084                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1085 
1086         // Emit an explicit barrier to replace the implicit fork-join barrier.
1087         if (CI != MergableCIs.back()) {
1088           // TODO: Remove barrier if the merged parallel region includes the
1089           // 'nowait' clause.
1090           OMPInfoCache.OMPBuilder.createBarrier(
1091               InsertPointTy(NewCI->getParent(),
1092                             NewCI->getNextNode()->getIterator()),
1093               OMPD_parallel);
1094         }
1095 
1096         CI->eraseFromParent();
1097       }
1098 
1099       assert(OutlinedFn != OriginalFn && "Outlining failed");
1100       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1101       CGUpdater.reanalyzeFunction(*OriginalFn);
1102 
1103       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1104 
1105       return true;
1106     };
1107 
1108     // Helper function that identifies sequences of
1109     // __kmpc_fork_call uses in a basic block.
1110     auto DetectPRsCB = [&](Use &U, Function &F) {
1111       CallInst *CI = getCallIfRegularCall(U, &RFI);
1112       if (CI)
1113         BB2PRMap[CI->getParent()].insert(CI);
1114       return false;
1115     };
1116 
1117     BB2PRMap.clear();
1118     RFI.foreachUse(SCC, DetectPRsCB);
1119     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1120     // Find mergable parallel regions within a basic block that are
1121     // safe to merge, that is any in-between instructions can safely
1122     // execute in parallel after merging.
1123     // TODO: support merging across basic-blocks.
1124     for (auto &It : BB2PRMap) {
1125       auto &CIs = It.getSecond();
1126       if (CIs.size() < 2)
1127         continue;
1128 
1129       BasicBlock *BB = It.getFirst();
1130       SmallVector<CallInst *, 4> MergableCIs;
1131 
1132       /// Returns true if the instruction is mergable, false otherwise.
1133       /// A terminator instruction is unmergable by definition since merging
1134       /// works within a BB. Instructions before the mergable region are
1135       /// mergable if they are not calls to OpenMP runtime functions that may
1136       /// set different execution parameters for subsequent parallel regions.
1137       /// Instructions in-between parallel regions are mergable if they are not
1138       /// calls to any non-intrinsic function since that may call a non-mergable
1139       /// OpenMP runtime function.
1140       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1141         // We do not merge across BBs, hence return false (unmergable) if the
1142         // instruction is a terminator.
1143         if (I.isTerminator())
1144           return false;
1145 
1146         if (!isa<CallInst>(&I))
1147           return true;
1148 
1149         CallInst *CI = cast<CallInst>(&I);
1150         if (IsBeforeMergableRegion) {
1151           Function *CalledFunction = CI->getCalledFunction();
1152           if (!CalledFunction)
1153             return false;
1154           // Return false (unmergable) if the call before the parallel
1155           // region calls an explicit affinity (proc_bind) or number of
1156           // threads (num_threads) compiler-generated function. Those settings
1157           // may be incompatible with following parallel regions.
1158           // TODO: ICV tracking to detect compatibility.
1159           for (const auto &RFI : UnmergableCallsInfo) {
1160             if (CalledFunction == RFI.Declaration)
1161               return false;
1162           }
1163         } else {
1164           // Return false (unmergable) if there is a call instruction
1165           // in-between parallel regions when it is not an intrinsic. It
1166           // may call an unmergable OpenMP runtime function in its callpath.
1167           // TODO: Keep track of possible OpenMP calls in the callpath.
1168           if (!isa<IntrinsicInst>(CI))
1169             return false;
1170         }
1171 
1172         return true;
1173       };
1174       // Find maximal number of parallel region CIs that are safe to merge.
1175       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1176         Instruction &I = *It;
1177         ++It;
1178 
1179         if (CIs.count(&I)) {
1180           MergableCIs.push_back(cast<CallInst>(&I));
1181           continue;
1182         }
1183 
1184         // Continue expanding if the instruction is mergable.
1185         if (IsMergable(I, MergableCIs.empty()))
1186           continue;
1187 
1188         // Forward the instruction iterator to skip the next parallel region
1189         // since there is an unmergable instruction which can affect it.
1190         for (; It != End; ++It) {
1191           Instruction &SkipI = *It;
1192           if (CIs.count(&SkipI)) {
1193             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1194                               << " due to " << I << "\n");
1195             ++It;
1196             break;
1197           }
1198         }
1199 
1200         // Store mergable regions found.
1201         if (MergableCIs.size() > 1) {
1202           MergableCIsVector.push_back(MergableCIs);
1203           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1204                             << " parallel regions in block " << BB->getName()
1205                             << " of function " << BB->getParent()->getName()
1206                             << "\n";);
1207         }
1208 
1209         MergableCIs.clear();
1210       }
1211 
1212       if (!MergableCIsVector.empty()) {
1213         Changed = true;
1214 
1215         for (auto &MergableCIs : MergableCIsVector)
1216           Merge(MergableCIs, BB);
1217         MergableCIsVector.clear();
1218       }
1219     }
1220 
1221     if (Changed) {
1222       /// Re-collect uses for fork calls, emitted barrier calls, and
1223       /// any emitted master/end_master calls.
1224       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1225       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1226       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1227       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1228     }
1229 
1230     return Changed;
1231   }
1232 
1233   /// Try to delete parallel regions if possible.
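  /// A __kmpc_fork_call can be removed if its outlined parallel callee only
  /// reads memory and is guaranteed to return, since such a region has no
  /// observable effect.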
1234   bool deleteParallelRegions() {
1235     const unsigned CallbackCalleeOperand = 2;
1236 
1237     OMPInformationCache::RuntimeFunctionInfo &RFI =
1238         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1239 
1240     if (!RFI.Declaration)
1241       return false;
1242 
1243     bool Changed = false;
1244     auto DeleteCallCB = [&](Use &U, Function &) {
1245       CallInst *CI = getCallIfRegularCall(U);
1246       if (!CI)
1247         return false;
1248       auto *Fn = dyn_cast<Function>(
1249           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1250       if (!Fn)
1251         return false;
1252       if (!Fn->onlyReadsMemory())
1253         return false;
1254       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1255         return false;
1256 
1257       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1258                         << CI->getCaller()->getName() << "\n");
1259 
1260       auto Remark = [&](OptimizationRemark OR) {
1261         return OR << "Removing parallel region with no side-effects.";
1262       };
1263       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1264 
1265       CGUpdater.removeCallSite(*CI);
1266       CI->eraseFromParent();
1267       Changed = true;
1268       ++NumOpenMPParallelRegionsDeleted;
1269       return true;
1270     };
1271 
1272     RFI.foreachUse(SCC, DeleteCallCB);
1273 
1274     return Changed;
1275   }
1276 
1277   /// Try to eliminate runtime calls by reusing existing ones.
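  /// For example (a sketch), two calls in the same function
  ///   %a = call i32 @omp_get_thread_num()
  ///   ...
  ///   %b = call i32 @omp_get_thread_num()
  /// can be collapsed so that users of %b reuse %a, provided the result is
  /// known to be unchanged in between.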
1278   bool deduplicateRuntimeCalls() {
1279     bool Changed = false;
1280 
1281     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1282         OMPRTL_omp_get_num_threads,
1283         OMPRTL_omp_in_parallel,
1284         OMPRTL_omp_get_cancellation,
1285         OMPRTL_omp_get_thread_limit,
1286         OMPRTL_omp_get_supported_active_levels,
1287         OMPRTL_omp_get_level,
1288         OMPRTL_omp_get_ancestor_thread_num,
1289         OMPRTL_omp_get_team_size,
1290         OMPRTL_omp_get_active_level,
1291         OMPRTL_omp_in_final,
1292         OMPRTL_omp_get_proc_bind,
1293         OMPRTL_omp_get_num_places,
1294         OMPRTL_omp_get_num_procs,
1295         OMPRTL_omp_get_place_num,
1296         OMPRTL_omp_get_partition_num_places,
1297         OMPRTL_omp_get_partition_place_nums};
1298 
1299     // Global-tid is handled separately.
1300     SmallSetVector<Value *, 16> GTIdArgs;
1301     collectGlobalThreadIdArguments(GTIdArgs);
1302     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1303                       << " global thread ID arguments\n");
1304 
1305     for (Function *F : SCC) {
1306       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1307         Changed |= deduplicateRuntimeCalls(
1308             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1309 
1310       // __kmpc_global_thread_num is special as we can replace it with an
1311       // argument in enough cases to make it worth trying.
1312       Value *GTIdArg = nullptr;
1313       for (Argument &Arg : F->args())
1314         if (GTIdArgs.count(&Arg)) {
1315           GTIdArg = &Arg;
1316           break;
1317         }
1318       Changed |= deduplicateRuntimeCalls(
1319           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1320     }
1321 
1322     return Changed;
1323   }
1324 
1325   /// Tries to hide the latency of runtime calls that involve host to
1326   /// device memory transfers by splitting them into their "issue" and "wait"
1327   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1328   /// moved downwards as much as possible. The "issue" issues the memory transfer
1329   /// asynchronously, returning a handle. The "wait" waits on the returned
1330   /// handle for the memory transfer to finish.
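  ///
  /// Schematically (a sketch; the exact runtime entry points and signatures
  /// are defined in OMPKinds.def), a call
  ///   call void @__tgt_target_data_begin_mapper(...)
  /// becomes
  ///   %handle = call ... @__tgt_target_data_begin_mapper_issue(...)
  ///   ; ... code the transfer can overlap with ...
  ///   call void @__tgt_target_data_begin_mapper_wait(..., %handle)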
1331   bool hideMemTransfersLatency() {
1332     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1333     bool Changed = false;
1334     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1335       auto *RTCall = getCallIfRegularCall(U, &RFI);
1336       if (!RTCall)
1337         return false;
1338 
1339       OffloadArray OffloadArrays[3];
1340       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1341         return false;
1342 
1343       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1344 
      // TODO: Check if it can be moved upwards.
1346       bool WasSplit = false;
1347       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1348       if (WaitMovementPoint)
1349         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1350 
1351       Changed |= WasSplit;
1352       return WasSplit;
1353     };
1354     RFI.foreachUse(SCC, SplitMemTransfers);
1355 
1356     return Changed;
1357   }
1358 
1359   void analysisGlobalization() {
1360     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1361 
1362     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1363       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1364         auto Remark = [&](OptimizationRemarkMissed ORM) {
1365           return ORM
1366                  << "Found thread data sharing on the GPU. "
1367                  << "Expect degraded performance due to data globalization.";
1368         };
1369         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1370       }
1371 
1372       return false;
1373     };
1374 
1375     RFI.foreachUse(SCC, CheckGlobalization);
1376   }
1377 
1378   /// Maps the values stored in the offload arrays passed as arguments to
1379   /// \p RuntimeCall into the offload arrays in \p OAs.
1380   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1381                                 MutableArrayRef<OffloadArray> OAs) {
1382     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1383 
1384     // A runtime call that involves memory offloading looks something like:
1385     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1386     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1387     // ...)
1388     // So, the idea is to access the allocas that allocate space for these
1389     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1390     // Therefore:
1391     // i8** %offload_baseptrs.
1392     Value *BasePtrsArg =
1393         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1394     // i8** %offload_ptrs.
1395     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1396     // i8** %offload_sizes.
1397     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1398 
1399     // Get values stored in **offload_baseptrs.
1400     auto *V = getUnderlyingObject(BasePtrsArg);
1401     if (!isa<AllocaInst>(V))
1402       return false;
1403     auto *BasePtrsArray = cast<AllocaInst>(V);
1404     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1405       return false;
1406 
    // Get values stored in **offload_ptrs.
1408     V = getUnderlyingObject(PtrsArg);
1409     if (!isa<AllocaInst>(V))
1410       return false;
1411     auto *PtrsArray = cast<AllocaInst>(V);
1412     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1413       return false;
1414 
1415     // Get values stored in **offload_sizes.
1416     V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array, don't analyze it.
1418     if (isa<GlobalValue>(V))
1419       return isa<Constant>(V);
1420     if (!isa<AllocaInst>(V))
1421       return false;
1422 
1423     auto *SizesArray = cast<AllocaInst>(V);
1424     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1425       return false;
1426 
1427     return true;
1428   }
1429 
1430   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1431   /// For now this is a way to test that the function getValuesInOffloadArrays
1432   /// is working properly.
1433   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1434   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1435     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1436 
1437     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1438     std::string ValuesStr;
1439     raw_string_ostream Printer(ValuesStr);
1440     std::string Separator = " --- ";
1441 
1442     for (auto *BP : OAs[0].StoredValues) {
1443       BP->print(Printer);
1444       Printer << Separator;
1445     }
1446     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1447     ValuesStr.clear();
1448 
1449     for (auto *P : OAs[1].StoredValues) {
1450       P->print(Printer);
1451       Printer << Separator;
1452     }
1453     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1454     ValuesStr.clear();
1455 
1456     for (auto *S : OAs[2].StoredValues) {
1457       S->print(Printer);
1458       Printer << Separator;
1459     }
1460     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1461   }
1462 
  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
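  ///
  /// A rough sketch of the intent (illustrative only): the "wait" sinks past
  /// instructions without memory effects and stops right before the first
  /// instruction that may read or write memory, e.g.,
  ///
  ///   <runtime call>            ; original position
  ///   %a = add i32 %x, 1        ; no memory effects, keep sinking
  ///   store i32 %a, i32* %p     ; may write memory -> "wait" goes here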
1465   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1466     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1467     //  Make it traverse the CFG.
1468 
1469     Instruction *CurrentI = &RuntimeCall;
1470     bool IsWorthIt = false;
1471     while ((CurrentI = CurrentI->getNextNode())) {
1472 
1473       // TODO: Once we detect the regions to be offloaded we should use the
1474       //  alias analysis manager to check if CurrentI may modify one of
1475       //  the offloaded regions.
1476       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1477         if (IsWorthIt)
1478           return CurrentI;
1479 
1480         return nullptr;
1481       }
1482 
      // FIXME: For now, moving it over anything without side effects is
      //  considered worth it.
1485       IsWorthIt = true;
1486     }
1487 
1488     // Return end of BasicBlock.
1489     return RuntimeCall.getParent()->getTerminator();
1490   }
1491 
1492   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1493   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1494                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It stores information about the async transfer, allowing
    // us to wait on it later.
1498     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1499     auto *F = RuntimeCall.getCaller();
1500     Instruction *FirstInst = &(F->getEntryBlock().front());
1501     AllocaInst *Handle = new AllocaInst(
1502         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1503 
1504     // Add "issue" runtime call declaration:
1505     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1506     //   i8**, i8**, i64*, i64*)
1507     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1508         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1509 
1510     // Change RuntimeCall call site for its asynchronous version.
1511     SmallVector<Value *, 16> Args;
1512     for (auto &Arg : RuntimeCall.args())
1513       Args.push_back(Arg.get());
1514     Args.push_back(Handle);
1515 
1516     CallInst *IssueCallsite =
1517         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1518     RuntimeCall.eraseFromParent();
1519 
1520     // Add "wait" runtime call declaration:
1521     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1522     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1523         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1524 
1525     Value *WaitParams[2] = {
1526         IssueCallsite->getArgOperand(
1527             OffloadArray::DeviceIDArgNum), // device_id.
1528         Handle                             // handle to wait on.
1529     };
1530     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1531 
1532     return true;
1533   }
1534 
1535   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1536                                     bool GlobalOnly, bool &SingleChoice) {
1537     if (CurrentIdent == NextIdent)
1538       return CurrentIdent;
1539 
1540     // TODO: Figure out how to actually combine multiple debug locations. For
1541     //       now we just keep an existing one if there is a single choice.
1542     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1543       SingleChoice = !CurrentIdent;
1544       return NextIdent;
1545     }
1546     return nullptr;
1547   }
1548 
  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value, we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
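  ///
  /// For example (an illustrative sketch): if every call passes the same
  /// global location,
  ///
  ///   %t0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @glob.ident)
  ///   %t1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @glob.ident)
  ///
  /// @glob.ident is the combined ident; otherwise a default one is created
  /// via the OpenMPIRBuilder.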
1554   Value *
1555   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1556                                  Function &F, bool GlobalOnly) {
1557     bool SingleChoice = true;
1558     Value *Ident = nullptr;
1559     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1560       CallInst *CI = getCallIfRegularCall(U, &RFI);
1561       if (!CI || &F != &Caller)
1562         return false;
1563       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1564                                   /* GlobalOnly */ true, SingleChoice);
1565       return false;
1566     };
1567     RFI.foreachUse(SCC, CombineIdentStruct);
1568 
1569     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate but we work around it for now.
1572       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1573         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1574             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1576       // TODO: Use the debug locations of the calls instead.
1577       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1578       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1579     }
1580     return Ident;
1581   }
1582 
1583   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1584   /// \p ReplVal if given.
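  ///
  /// For example (an illustrative sketch): if \p ReplVal is an argument of
  /// \p F that already holds the global thread ID, calls like
  ///
  ///   %tid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @ident)
  ///
  /// inside \p F are replaced by that argument and removed entirely.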
1585   bool deduplicateRuntimeCalls(Function &F,
1586                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1587                                Value *ReplVal = nullptr) {
1588     auto *UV = RFI.getUseVector(F);
1589     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1590       return false;
1591 
    LLVM_DEBUG(dbgs() << TAG << "Deduplicate " << UV->size() << " uses of "
                      << RFI.Name
                      << (ReplVal ? " with an existing value" : "") << "\n");
1595 
1596     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1597                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1598            "Unexpected replacement value!");
1599 
1600     // TODO: Use dominance to find a good position instead.
1601     auto CanBeMoved = [this](CallBase &CB) {
1602       unsigned NumArgs = CB.getNumArgOperands();
1603       if (NumArgs == 0)
1604         return true;
1605       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1606         return false;
1607       for (unsigned u = 1; u < NumArgs; ++u)
1608         if (isa<Instruction>(CB.getArgOperand(u)))
1609           return false;
1610       return true;
1611     };
1612 
1613     if (!ReplVal) {
1614       for (Use *U : *UV)
1615         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1616           if (!CanBeMoved(*CI))
1617             continue;
1618 
1619           // If the function is a kernel, dedup will move
1620           // the runtime call right after the kernel init callsite. Otherwise,
1621           // it will move it to the beginning of the caller function.
1622           if (isKernel(F)) {
1623             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1624             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1625 
1626             if (KernelInitUV->empty())
1627               continue;
1628 
1629             assert(KernelInitUV->size() == 1 &&
1630                    "Expected a single __kmpc_target_init in kernel\n");
1631 
1632             CallInst *KernelInitCI =
1633                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1634             assert(KernelInitCI &&
1635                    "Expected a call to __kmpc_target_init in kernel\n");
1636 
1637             CI->moveAfter(KernelInitCI);
1638           } else
1639             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1640           ReplVal = CI;
1641           break;
1642         }
1643       if (!ReplVal)
1644         return false;
1645     }
1646 
1647     // If we use a call as a replacement value we need to make sure the ident is
1648     // valid at the new location. For now we just pick a global one, either
1649     // existing and used by one of the calls, or created from scratch.
1650     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1651       if (!CI->arg_empty() &&
1652           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1653         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1654                                                       /* GlobalOnly */ true);
1655         CI->setArgOperand(0, Ident);
1656       }
1657     }
1658 
1659     bool Changed = false;
1660     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1661       CallInst *CI = getCallIfRegularCall(U, &RFI);
1662       if (!CI || CI == ReplVal || &F != &Caller)
1663         return false;
1664       assert(CI->getCaller() == &F && "Unexpected call!");
1665 
1666       auto Remark = [&](OptimizationRemark OR) {
1667         return OR << "OpenMP runtime call "
1668                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1669       };
1670       if (CI->getDebugLoc())
1671         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1672       else
1673         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1674 
1675       CGUpdater.removeCallSite(*CI);
1676       CI->replaceAllUsesWith(ReplVal);
1677       CI->eraseFromParent();
1678       ++NumOpenMPRuntimeCallsDeduplicated;
1679       Changed = true;
1680       return true;
1681     };
1682     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1683 
1684     return Changed;
1685   }
1686 
1687   /// Collect arguments that represent the global thread id in \p GTIdArgs.
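  ///
  /// A sketch of the pattern we look for (illustrative only):
  ///
  ///   %tid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @ident)
  ///   call void @helper(i32 %tid)    ; local-linkage callee
  ///
  /// If every call site of @helper passes a known global thread ID (or a
  /// direct __kmpc_global_thread_num result) in that position, the
  /// corresponding argument of @helper is recorded as a GTId as well.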
1688   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1689     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1690     //       initialization. We could define an AbstractAttribute instead and
1691     //       run the Attributor here once it can be run as an SCC pass.
1692 
1693     // Helper to check the argument \p ArgNo at all call sites of \p F for
1694     // a GTId.
1695     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1696       if (!F.hasLocalLinkage())
1697         return false;
1698       for (Use &U : F.uses()) {
1699         if (CallInst *CI = getCallIfRegularCall(U)) {
1700           Value *ArgOp = CI->getArgOperand(ArgNo);
1701           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1702               getCallIfRegularCall(
1703                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1704             continue;
1705         }
1706         return false;
1707       }
1708       return true;
1709     };
1710 
1711     // Helper to identify uses of a GTId as GTId arguments.
1712     auto AddUserArgs = [&](Value &GTId) {
1713       for (Use &U : GTId.uses())
1714         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1715           if (CI->isArgOperand(&U))
1716             if (Function *Callee = CI->getCalledFunction())
1717               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1718                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1719     };
1720 
1721     // The argument users of __kmpc_global_thread_num calls are GTIds.
1722     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1723         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1724 
1725     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1726       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1727         AddUserArgs(*CI);
1728       return false;
1729     });
1730 
1731     // Transitively search for more arguments by looking at the users of the
1732     // ones we know already. During the search the GTIdArgs vector is extended
1733     // so we cannot cache the size nor can we use a range based for.
1734     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1735       AddUserArgs(*GTIdArgs[u]);
1736   }
1737 
1738   /// Kernel (=GPU) optimizations and utility functions
1739   ///
1740   ///{{
1741 
1742   /// Check if \p F is a kernel, hence entry point for target offloading.
1743   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1744 
1745   /// Cache to remember the unique kernel for a function.
1746   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1747 
1748   /// Find the unique kernel that will execute \p F, if any.
1749   Kernel getUniqueKernelFor(Function &F);
1750 
1751   /// Find the unique kernel that will execute \p I, if any.
1752   Kernel getUniqueKernelFor(Instruction &I) {
1753     return getUniqueKernelFor(*I.getFunction());
1754   }
1755 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
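  ///
  /// Roughly (an illustrative sketch, not exact IR; names and argument
  /// positions are made up): the parallel region pointer handed to
  /// __kmpc_parallel_51, and compared against in the state machine, is
  /// replaced by a private global identifier, e.g.,
  ///
  ///   call void @__kmpc_parallel_51(..., i8* bitcast (... @wrapper to i8*), ...)
  ///
  /// becomes
  ///
  ///   @wrapper.ID = private constant i8 undef
  ///   call void @__kmpc_parallel_51(..., i8* @wrapper.ID, ...)
  ///
  /// so that only direct calls to @wrapper remain.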
1758   bool rewriteDeviceCodeStateMachine();
1759 
1760   ///
1761   ///}}
1762 
1763   /// Emit a remark generically
1764   ///
1765   /// This template function can be used to generically emit a remark. The
1766   /// RemarkKind should be one of the following:
1767   ///   - OptimizationRemark to indicate a successful optimization attempt
1768   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1769   ///   - OptimizationRemarkAnalysis to provide additional information about an
1770   ///     optimization attempt
1771   ///
1772   /// The remark is built using a callback function provided by the caller that
1773   /// takes a RemarkKind as input and returns a RemarkKind.
1774   template <typename RemarkKind, typename RemarkCallBack>
1775   void emitRemark(Instruction *I, StringRef RemarkName,
1776                   RemarkCallBack &&RemarkCB) const {
1777     Function *F = I->getParent()->getParent();
1778     auto &ORE = OREGetter(F);
1779 
1780     if (RemarkName.startswith("OMP"))
1781       ORE.emit([&]() {
1782         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1783                << " [" << RemarkName << "]";
1784       });
1785     else
1786       ORE.emit(
1787           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1788   }
1789 
1790   /// Emit a remark on a function.
1791   template <typename RemarkKind, typename RemarkCallBack>
1792   void emitRemark(Function *F, StringRef RemarkName,
1793                   RemarkCallBack &&RemarkCB) const {
1794     auto &ORE = OREGetter(F);
1795 
1796     if (RemarkName.startswith("OMP"))
1797       ORE.emit([&]() {
1798         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1799                << " [" << RemarkName << "]";
1800       });
1801     else
1802       ORE.emit(
1803           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1804   }
1805 
1806   /// RAII struct to temporarily change an RTL function's linkage to external.
1807   /// This prevents it from being mistakenly removed by other optimizations.
1808   struct ExternalizationRAII {
1809     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1810                         RuntimeFunction RFKind)
1811         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1812       if (!Declaration)
1813         return;
1814 
1815       LinkageType = Declaration->getLinkage();
1816       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1817     }
1818 
1819     ~ExternalizationRAII() {
1820       if (!Declaration)
1821         return;
1822 
1823       Declaration->setLinkage(LinkageType);
1824     }
1825 
1826     Function *Declaration;
1827     GlobalValue::LinkageTypes LinkageType;
1828   };
1829 
1830   /// The underlying module.
1831   Module &M;
1832 
1833   /// The SCC we are operating on.
1834   SmallVectorImpl<Function *> &SCC;
1835 
1836   /// Callback to update the call graph, the first argument is a removed call,
1837   /// the second an optional replacement call.
1838   CallGraphUpdater &CGUpdater;
1839 
1840   /// Callback to get an OptimizationRemarkEmitter from a Function *
1841   OptimizationRemarkGetter OREGetter;
1842 
  /// OpenMP-specific information cache. Also used for Attributor runs.
1844   OMPInformationCache &OMPInfoCache;
1845 
1846   /// Attributor instance.
1847   Attributor &A;
1848 
1849   /// Helper function to run Attributor on SCC.
1850   bool runAttributor(bool IsModulePass) {
1851     if (SCC.empty())
1852       return false;
1853 
    // Temporarily make these functions have external linkage so the Attributor
    // doesn't remove them when we try to look them up later.
1856     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1857     ExternalizationRAII EndParallel(OMPInfoCache,
1858                                     OMPRTL___kmpc_kernel_end_parallel);
1859     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1860                                     OMPRTL___kmpc_barrier_simple_spmd);
1861 
1862     registerAAs(IsModulePass);
1863 
1864     ChangeStatus Changed = A.run();
1865 
1866     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1867                       << " functions, result: " << Changed << ".\n");
1868 
1869     return Changed == ChangeStatus::CHANGED;
1870   }
1871 
1872   void registerFoldRuntimeCall(RuntimeFunction RF);
1873 
1874   /// Populate the Attributor with abstract attribute opportunities in the
1875   /// function.
1876   void registerAAs(bool IsModulePass);
1877 };
1878 
1879 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1880   if (!OMPInfoCache.ModuleSlice.count(&F))
1881     return nullptr;
1882 
1883   // Use a scope to keep the lifetime of the CachedKernel short.
1884   {
1885     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1886     if (CachedKernel)
1887       return *CachedKernel;
1888 
1889     // TODO: We should use an AA to create an (optimistic and callback
1890     //       call-aware) call graph. For now we stick to simple patterns that
1891     //       are less powerful, basically the worst fixpoint.
1892     if (isKernel(F)) {
1893       CachedKernel = Kernel(&F);
1894       return *CachedKernel;
1895     }
1896 
1897     CachedKernel = nullptr;
1898     if (!F.hasLocalLinkage()) {
1899 
1900       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1901       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1902         return ORA << "Potentially unknown OpenMP target region caller.";
1903       };
1904       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1905 
1906       return nullptr;
1907     }
1908   }
1909 
1910   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1911     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1912       // Allow use in equality comparisons.
1913       if (Cmp->isEquality())
1914         return getUniqueKernelFor(*Cmp);
1915       return nullptr;
1916     }
1917     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1918       // Allow direct calls.
1919       if (CB->isCallee(&U))
1920         return getUniqueKernelFor(*CB);
1921 
1922       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1923           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1924       // Allow the use in __kmpc_parallel_51 calls.
1925       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1926         return getUniqueKernelFor(*CB);
1927       return nullptr;
1928     }
1929     // Disallow every other use.
1930     return nullptr;
1931   };
1932 
1933   // TODO: In the future we want to track more than just a unique kernel.
1934   SmallPtrSet<Kernel, 2> PotentialKernels;
1935   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1936     PotentialKernels.insert(GetUniqueKernelForUse(U));
1937   });
1938 
1939   Kernel K = nullptr;
1940   if (PotentialKernels.size() == 1)
1941     K = *PotentialKernels.begin();
1942 
1943   // Cache the result.
1944   UniqueKernelMap[&F] = K;
1945 
1946   return K;
1947 }
1948 
1949 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1950   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1951       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1952 
1953   bool Changed = false;
1954   if (!KernelParallelRFI)
1955     return Changed;
1956 
1957   // If we have disabled state machine changes, exit
1958   if (DisableOpenMPOptStateMachineRewrite)
1959     return Changed;
1960 
1961   for (Function *F : SCC) {
1962 
1963     // Check if the function is a use in a __kmpc_parallel_51 call at
1964     // all.
1965     bool UnknownUse = false;
1966     bool KernelParallelUse = false;
1967     unsigned NumDirectCalls = 0;
1968 
1969     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1970     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1971       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1972         if (CB->isCallee(&U)) {
1973           ++NumDirectCalls;
1974           return;
1975         }
1976 
1977       if (isa<ICmpInst>(U.getUser())) {
1978         ToBeReplacedStateMachineUses.push_back(&U);
1979         return;
1980       }
1981 
1982       // Find wrapper functions that represent parallel kernels.
1983       CallInst *CI =
1984           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1985       const unsigned int WrapperFunctionArgNo = 6;
1986       if (!KernelParallelUse && CI &&
1987           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1988         KernelParallelUse = true;
1989         ToBeReplacedStateMachineUses.push_back(&U);
1990         return;
1991       }
1992       UnknownUse = true;
1993     });
1994 
1995     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1996     // use.
1997     if (!KernelParallelUse)
1998       continue;
1999 
2000     // If this ever hits, we should investigate.
2001     // TODO: Checking the number of uses is not a necessary restriction and
2002     // should be lifted.
2003     if (UnknownUse || NumDirectCalls != 1 ||
2004         ToBeReplacedStateMachineUses.size() > 2) {
2005       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2006         return ORA << "Parallel region is used in "
2007                    << (UnknownUse ? "unknown" : "unexpected")
2008                    << " ways. Will not attempt to rewrite the state machine.";
2009       };
2010       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2011       continue;
2012     }
2013 
2014     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2015     // up if the function is not called from a unique kernel.
2016     Kernel K = getUniqueKernelFor(*F);
2017     if (!K) {
2018       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2019         return ORA << "Parallel region is not called from a unique kernel. "
2020                       "Will not attempt to rewrite the state machine.";
2021       };
2022       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2023       continue;
2024     }
2025 
2026     // We now know F is a parallel body function called only from the kernel K.
2027     // We also identified the state machine uses in which we replace the
2028     // function pointer by a new global symbol for identification purposes. This
2029     // ensures only direct calls to the function are left.
2030 
2031     Module &M = *F->getParent();
2032     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2033 
2034     auto *ID = new GlobalVariable(
2035         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2036         UndefValue::get(Int8Ty), F->getName() + ".ID");
2037 
2038     for (Use *U : ToBeReplacedStateMachineUses)
2039       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
2040 
2041     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2042 
2043     Changed = true;
2044   }
2045 
2046   return Changed;
2047 }
2048 
2049 /// Abstract Attribute for tracking ICV values.
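///
/// For example (an illustrative sketch, using the `nthreads` ICV):
///
///   call void @omp_set_num_threads(i32 4)
///   %n = call i32 @omp_get_max_threads()
///
/// the tracker can conclude that the ICV value at the getter is 4, so the
/// getter call may later be replaced by that constant (see
/// AAICVTrackerCallSite::manifest).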
2050 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2051   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2052   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2053 
2054   void initialize(Attributor &A) override {
2055     Function *F = getAnchorScope();
2056     if (!F || !A.isFunctionIPOAmendable(*F))
2057       indicatePessimisticFixpoint();
2058   }
2059 
2060   /// Returns true if value is assumed to be tracked.
2061   bool isAssumedTracked() const { return getAssumed(); }
2062 
2063   /// Returns true if value is known to be tracked.
2064   bool isKnownTracked() const { return getAssumed(); }
2065 
  /// Create an abstract attribute view for the position \p IRP.
2067   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2068 
2069   /// Return the value with which \p I can be replaced for specific \p ICV.
2070   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2071                                                 const Instruction *I,
2072                                                 Attributor &A) const {
2073     return None;
2074   }
2075 
  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return nullptr. If it is not clear yet, return
  /// None.
2079   virtual Optional<Value *>
2080   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2081 
  // Currently only nthreads is being tracked.
  // This array will only grow over time.
2084   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2085 
2086   /// See AbstractAttribute::getName()
2087   const std::string getName() const override { return "AAICVTracker"; }
2088 
2089   /// See AbstractAttribute::getIdAddr()
2090   const char *getIdAddr() const override { return &ID; }
2091 
2092   /// This function should return true if the type of the \p AA is AAICVTracker
2093   static bool classof(const AbstractAttribute *AA) {
2094     return (AA->getIdAddr() == &ID);
2095   }
2096 
2097   static const char ID;
2098 };
2099 
2100 struct AAICVTrackerFunction : public AAICVTracker {
2101   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2102       : AAICVTracker(IRP, A) {}
2103 
2104   // FIXME: come up with better string.
2105   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2106 
2107   // FIXME: come up with some stats.
2108   void trackStatistics() const override {}
2109 
2110   /// We don't manifest anything for this AA.
2111   ChangeStatus manifest(Attributor &A) override {
2112     return ChangeStatus::UNCHANGED;
2113   }
2114 
2115   // Map of ICV to their values at specific program point.
2116   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2117                   InternalControlVar::ICV___last>
2118       ICVReplacementValuesMap;
2119 
2120   ChangeStatus updateImpl(Attributor &A) override {
2121     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2122 
2123     Function *F = getAnchorScope();
2124 
2125     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2126 
2127     for (InternalControlVar ICV : TrackableICVs) {
2128       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2129 
2130       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2131       auto TrackValues = [&](Use &U, Function &) {
2132         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2133         if (!CI)
2134           return false;
2135 
        // FIXME: handle setters with more than one argument.
2137         /// Track new value.
2138         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2139           HasChanged = ChangeStatus::CHANGED;
2140 
2141         return false;
2142       };
2143 
2144       auto CallCheck = [&](Instruction &I) {
2145         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2146         if (ReplVal.hasValue() &&
2147             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2148           HasChanged = ChangeStatus::CHANGED;
2149 
2150         return true;
2151       };
2152 
2153       // Track all changes of an ICV.
2154       SetterRFI.foreachUse(TrackValues, F);
2155 
2156       bool UsedAssumedInformation = false;
2157       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2158                                 UsedAssumedInformation,
2159                                 /* CheckBBLivenessOnly */ true);
2160 
      /// TODO: Figure out a way to avoid adding an entry in
      /// ICVReplacementValuesMap.
2163       Instruction *Entry = &F->getEntryBlock().front();
2164       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2165         ValuesMap.insert(std::make_pair(Entry, nullptr));
2166     }
2167 
2168     return HasChanged;
2169   }
2170 
  /// Helper to check if \p I is a call and get the value for it if it is
  /// unique.
2173   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2174                                     InternalControlVar &ICV) const {
2175 
2176     const auto *CB = dyn_cast<CallBase>(I);
2177     if (!CB || CB->hasFnAttr("no_openmp") ||
2178         CB->hasFnAttr("no_openmp_routines"))
2179       return None;
2180 
2181     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2182     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2183     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2184     Function *CalledFunction = CB->getCalledFunction();
2185 
2186     // Indirect call, assume ICV changes.
2187     if (CalledFunction == nullptr)
2188       return nullptr;
2189     if (CalledFunction == GetterRFI.Declaration)
2190       return None;
2191     if (CalledFunction == SetterRFI.Declaration) {
2192       if (ICVReplacementValuesMap[ICV].count(I))
2193         return ICVReplacementValuesMap[ICV].lookup(I);
2194 
2195       return nullptr;
2196     }
2197 
2198     // Since we don't know, assume it changes the ICV.
2199     if (CalledFunction->isDeclaration())
2200       return nullptr;
2201 
2202     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2203         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2204 
2205     if (ICVTrackingAA.isAssumedTracked())
2206       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2207 
2208     // If we don't know, assume it changes.
2209     return nullptr;
2210   }
2211 
2212   // We don't check unique value for a function, so return None.
2213   Optional<Value *>
2214   getUniqueReplacementValue(InternalControlVar ICV) const override {
2215     return None;
2216   }
2217 
2218   /// Return the value with which \p I can be replaced for specific \p ICV.
2219   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2220                                         const Instruction *I,
2221                                         Attributor &A) const override {
2222     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2223     if (ValuesMap.count(I))
2224       return ValuesMap.lookup(I);
2225 
2226     SmallVector<const Instruction *, 16> Worklist;
2227     SmallPtrSet<const Instruction *, 16> Visited;
2228     Worklist.push_back(I);
2229 
2230     Optional<Value *> ReplVal;
2231 
2232     while (!Worklist.empty()) {
2233       const Instruction *CurrInst = Worklist.pop_back_val();
2234       if (!Visited.insert(CurrInst).second)
2235         continue;
2236 
2237       const BasicBlock *CurrBB = CurrInst->getParent();
2238 
2239       // Go up and look for all potential setters/calls that might change the
2240       // ICV.
2241       while ((CurrInst = CurrInst->getPrevNode())) {
2242         if (ValuesMap.count(CurrInst)) {
2243           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2244           // Unknown value, track new.
2245           if (!ReplVal.hasValue()) {
2246             ReplVal = NewReplVal;
2247             break;
2248           }
2249 
          // If we found a new value, we can't know the ICV value anymore.
2251           if (NewReplVal.hasValue())
2252             if (ReplVal != NewReplVal)
2253               return nullptr;
2254 
2255           break;
2256         }
2257 
2258         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2259         if (!NewReplVal.hasValue())
2260           continue;
2261 
2262         // Unknown value, track new.
2263         if (!ReplVal.hasValue()) {
2264           ReplVal = NewReplVal;
2265           break;
2266         }
2267 
        // We found a new value, so we can't know the ICV value anymore.
2270         if (ReplVal != NewReplVal)
2271           return nullptr;
2272       }
2273 
2274       // If we are in the same BB and we have a value, we are done.
2275       if (CurrBB == I->getParent() && ReplVal.hasValue())
2276         return ReplVal;
2277 
2278       // Go through all predecessors and add terminators for analysis.
2279       for (const BasicBlock *Pred : predecessors(CurrBB))
2280         if (const Instruction *Terminator = Pred->getTerminator())
2281           Worklist.push_back(Terminator);
2282     }
2283 
2284     return ReplVal;
2285   }
2286 };
2287 
2288 struct AAICVTrackerFunctionReturned : AAICVTracker {
2289   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2290       : AAICVTracker(IRP, A) {}
2291 
2292   // FIXME: come up with better string.
2293   const std::string getAsStr() const override {
2294     return "ICVTrackerFunctionReturned";
2295   }
2296 
2297   // FIXME: come up with some stats.
2298   void trackStatistics() const override {}
2299 
2300   /// We don't manifest anything for this AA.
2301   ChangeStatus manifest(Attributor &A) override {
2302     return ChangeStatus::UNCHANGED;
2303   }
2304 
2305   // Map of ICV to their values at specific program point.
2306   EnumeratedArray<Optional<Value *>, InternalControlVar,
2307                   InternalControlVar::ICV___last>
2308       ICVReplacementValuesMap;
2309 
2310   /// Return the value with which \p I can be replaced for specific \p ICV.
2311   Optional<Value *>
2312   getUniqueReplacementValue(InternalControlVar ICV) const override {
2313     return ICVReplacementValuesMap[ICV];
2314   }
2315 
2316   ChangeStatus updateImpl(Attributor &A) override {
2317     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2318     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2319         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2320 
2321     if (!ICVTrackingAA.isAssumedTracked())
2322       return indicatePessimisticFixpoint();
2323 
2324     for (InternalControlVar ICV : TrackableICVs) {
2325       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2326       Optional<Value *> UniqueICVValue;
2327 
2328       auto CheckReturnInst = [&](Instruction &I) {
2329         Optional<Value *> NewReplVal =
2330             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2331 
2332         // If we found a second ICV value there is no unique returned value.
2333         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2334           return false;
2335 
2336         UniqueICVValue = NewReplVal;
2337 
2338         return true;
2339       };
2340 
2341       bool UsedAssumedInformation = false;
2342       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2343                                      UsedAssumedInformation,
2344                                      /* CheckBBLivenessOnly */ true))
2345         UniqueICVValue = nullptr;
2346 
2347       if (UniqueICVValue == ReplVal)
2348         continue;
2349 
2350       ReplVal = UniqueICVValue;
2351       Changed = ChangeStatus::CHANGED;
2352     }
2353 
2354     return Changed;
2355   }
2356 };
2357 
2358 struct AAICVTrackerCallSite : AAICVTracker {
2359   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2360       : AAICVTracker(IRP, A) {}
2361 
2362   void initialize(Attributor &A) override {
2363     Function *F = getAnchorScope();
2364     if (!F || !A.isFunctionIPOAmendable(*F))
2365       indicatePessimisticFixpoint();
2366 
2367     // We only initialize this AA for getters, so we need to know which ICV it
2368     // gets.
2369     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2370     for (InternalControlVar ICV : TrackableICVs) {
2371       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2372       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2373       if (Getter.Declaration == getAssociatedFunction()) {
2374         AssociatedICV = ICVInfo.Kind;
2375         return;
2376       }
2377     }
2378 
2379     /// Unknown ICV.
2380     indicatePessimisticFixpoint();
2381   }
2382 
2383   ChangeStatus manifest(Attributor &A) override {
2384     if (!ReplVal.hasValue() || !ReplVal.getValue())
2385       return ChangeStatus::UNCHANGED;
2386 
2387     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2388     A.deleteAfterManifest(*getCtxI());
2389 
2390     return ChangeStatus::CHANGED;
2391   }
2392 
2393   // FIXME: come up with better string.
2394   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2395 
2396   // FIXME: come up with some stats.
2397   void trackStatistics() const override {}
2398 
2399   InternalControlVar AssociatedICV;
2400   Optional<Value *> ReplVal;
2401 
2402   ChangeStatus updateImpl(Attributor &A) override {
2403     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2404         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2405 
2406     // We don't have any information, so we assume it changes the ICV.
2407     if (!ICVTrackingAA.isAssumedTracked())
2408       return indicatePessimisticFixpoint();
2409 
2410     Optional<Value *> NewReplVal =
2411         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2412 
2413     if (ReplVal == NewReplVal)
2414       return ChangeStatus::UNCHANGED;
2415 
2416     ReplVal = NewReplVal;
2417     return ChangeStatus::CHANGED;
2418   }
2419 
  // Return the value with which the associated value can be replaced for the
  // specific \p ICV.
2422   Optional<Value *>
2423   getUniqueReplacementValue(InternalControlVar ICV) const override {
2424     return ReplVal;
2425   }
2426 };
2427 
2428 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2429   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2430       : AAICVTracker(IRP, A) {}
2431 
2432   // FIXME: come up with better string.
2433   const std::string getAsStr() const override {
2434     return "ICVTrackerCallSiteReturned";
2435   }
2436 
2437   // FIXME: come up with some stats.
2438   void trackStatistics() const override {}
2439 
2440   /// We don't manifest anything for this AA.
2441   ChangeStatus manifest(Attributor &A) override {
2442     return ChangeStatus::UNCHANGED;
2443   }
2444 
2445   // Map of ICV to their values at specific program point.
2446   EnumeratedArray<Optional<Value *>, InternalControlVar,
2447                   InternalControlVar::ICV___last>
2448       ICVReplacementValuesMap;
2449 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2452   Optional<Value *>
2453   getUniqueReplacementValue(InternalControlVar ICV) const override {
2454     return ICVReplacementValuesMap[ICV];
2455   }
2456 
2457   ChangeStatus updateImpl(Attributor &A) override {
2458     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2459     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2460         *this, IRPosition::returned(*getAssociatedFunction()),
2461         DepClassTy::REQUIRED);
2462 
2463     // We don't have any information, so we assume it changes the ICV.
2464     if (!ICVTrackingAA.isAssumedTracked())
2465       return indicatePessimisticFixpoint();
2466 
2467     for (InternalControlVar ICV : TrackableICVs) {
2468       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2469       Optional<Value *> NewReplVal =
2470           ICVTrackingAA.getUniqueReplacementValue(ICV);
2471 
2472       if (ReplVal == NewReplVal)
2473         continue;
2474 
2475       ReplVal = NewReplVal;
2476       Changed = ChangeStatus::CHANGED;
2477     }
2478     return Changed;
2479   }
2480 };
2481 
2482 struct AAExecutionDomainFunction : public AAExecutionDomain {
2483   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2484       : AAExecutionDomain(IRP, A) {}
2485 
2486   const std::string getAsStr() const override {
2487     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2488            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2489   }
2490 
2491   /// See AbstractAttribute::trackStatistics().
2492   void trackStatistics() const override {}
2493 
2494   void initialize(Attributor &A) override {
2495     Function *F = getAnchorScope();
2496     for (const auto &BB : *F)
2497       SingleThreadedBBs.insert(&BB);
2498     NumBBs = SingleThreadedBBs.size();
2499   }
2500 
2501   ChangeStatus manifest(Attributor &A) override {
2502     LLVM_DEBUG({
2503       for (const BasicBlock *BB : SingleThreadedBBs)
2504         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2505                << BB->getName() << " is executed by a single thread.\n";
2506     });
2507     return ChangeStatus::UNCHANGED;
2508   }
2509 
2510   ChangeStatus updateImpl(Attributor &A) override;
2511 
  /// Check if an instruction is only executed by the initial thread.
2513   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2514     return isExecutedByInitialThreadOnly(*I.getParent());
2515   }
2516 
2517   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2518     return isValidState() && SingleThreadedBBs.contains(&BB);
2519   }
2520 
2521   /// Set of basic blocks that are executed by a single thread.
2522   DenseSet<const BasicBlock *> SingleThreadedBBs;
2523 
2524   /// Total number of basic blocks in this function.
2525   long unsigned NumBBs;
2526 };
2527 
2528 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2529   Function *F = getAnchorScope();
2530   ReversePostOrderTraversal<Function *> RPOT(F);
2531   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2532 
2533   bool AllCallSitesKnown;
2534   auto PredForCallSite = [&](AbstractCallSite ACS) {
2535     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2536         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2537         DepClassTy::REQUIRED);
2538     return ACS.isDirectCall() &&
2539            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2540                *ACS.getInstruction());
2541   };
2542 
2543   if (!A.checkForAllCallSites(PredForCallSite, *this,
2544                               /* RequiresAllCallSites */ true,
2545                               AllCallSitesKnown))
2546     SingleThreadedBBs.erase(&F->getEntryBlock());
2547 
2548   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2549   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2550 
  // Check if the edge into the successor block compares the __kmpc_target_init
  // result with -1. In non-SPMD mode this signals that only the main thread
  // will execute the edge.
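  //
  // The matched pattern looks roughly like this (illustrative only, the exact
  // __kmpc_target_init signature may differ):
  //
  //   %tid = call i32 @__kmpc_target_init(..., i1 false /* IsSPMD */, ...)
  //   %cmp = icmp eq i32 %tid, -1
  //   br i1 %cmp, label %initial.thread.only, label %worker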
2554   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2555     if (!Edge || !Edge->isConditional())
2556       return false;
2557     if (Edge->getSuccessor(0) != SuccessorBB)
2558       return false;
2559 
2560     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2561     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2562       return false;
2563 
2564     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2565     if (!C)
2566       return false;
2567 
2568     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2569     if (C->isAllOnesValue()) {
2570       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2571       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2572       if (!CB)
2573         return false;
2574       const int InitIsSPMDArgNo = 1;
2575       auto *IsSPMDModeCI =
2576           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2577       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2578     }
2579 
2580     return false;
2581   };
2582 
2583   // Merge all the predecessor states into the current basic block. A basic
2584   // block is executed by a single thread if all of its predecessors are.
2585   auto MergePredecessorStates = [&](BasicBlock *BB) {
2586     if (pred_begin(BB) == pred_end(BB))
2587       return SingleThreadedBBs.contains(BB);
2588 
2589     bool IsInitialThread = true;
2590     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2591          PredBB != PredEndBB; ++PredBB) {
2592       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2593                                BB))
2594         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2595     }
2596 
2597     return IsInitialThread;
2598   };
2599 
2600   for (auto *BB : RPOT) {
2601     if (!MergePredecessorStates(BB))
2602       SingleThreadedBBs.erase(BB);
2603   }
2604 
2605   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2606              ? ChangeStatus::UNCHANGED
2607              : ChangeStatus::CHANGED;
2608 }
2609 
/// Try to replace memory allocation calls that are only executed by a single
/// thread with a static buffer of shared memory.
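///
/// Roughly (an illustrative sketch, not exact IR):
///
///   %p = call i8* @__kmpc_alloc_shared(i64 8)
///   ...
///   call void @__kmpc_free_shared(i8* %p)
///
/// becomes, provided the size is constant and only the initial thread runs
/// the call,
///
///   @shared = internal addrspace(3) global [8 x i8] undef  ; e.g., on NVPTX
///   ... uses of a pointer cast of @shared ...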
2612 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2613   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2614   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2615 
2616   /// Create an abstract attribute view for the position \p IRP.
2617   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2618                                            Attributor &A);
2619 
2620   /// Returns true if HeapToShared conversion is assumed to be possible.
2621   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2622 
2623   /// Returns true if HeapToShared conversion is assumed and the CB is a
2624   /// callsite to a free operation to be removed.
2625   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2626 
2627   /// See AbstractAttribute::getName().
2628   const std::string getName() const override { return "AAHeapToShared"; }
2629 
2630   /// See AbstractAttribute::getIdAddr().
2631   const char *getIdAddr() const override { return &ID; }
2632 
2633   /// This function should return true if the type of the \p AA is
2634   /// AAHeapToShared.
2635   static bool classof(const AbstractAttribute *AA) {
2636     return (AA->getIdAddr() == &ID);
2637   }
2638 
2639   /// Unique ID (due to the unique address)
2640   static const char ID;
2641 };
2642 
2643 struct AAHeapToSharedFunction : public AAHeapToShared {
2644   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2645       : AAHeapToShared(IRP, A) {}
2646 
2647   const std::string getAsStr() const override {
2648     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2649            " malloc calls eligible.";
2650   }
2651 
2652   /// See AbstractAttribute::trackStatistics().
2653   void trackStatistics() const override {}
2654 
  /// This function finds free calls that will be removed by the
2656   /// HeapToShared transformation.
2657   void findPotentialRemovedFreeCalls(Attributor &A) {
2658     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2659     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2660 
2661     PotentialRemovedFreeCalls.clear();
2662     // Update free call users of found malloc calls.
2663     for (CallBase *CB : MallocCalls) {
2664       SmallVector<CallBase *, 4> FreeCalls;
2665       for (auto *U : CB->users()) {
2666         CallBase *C = dyn_cast<CallBase>(U);
2667         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2668           FreeCalls.push_back(C);
2669       }
2670 
2671       if (FreeCalls.size() != 1)
2672         continue;
2673 
2674       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2675     }
2676   }
2677 
2678   void initialize(Attributor &A) override {
2679     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2680     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2681 
2682     for (User *U : RFI.Declaration->users())
2683       if (CallBase *CB = dyn_cast<CallBase>(U))
2684         MallocCalls.insert(CB);
2685 
2686     findPotentialRemovedFreeCalls(A);
2687   }
2688 
2689   bool isAssumedHeapToShared(CallBase &CB) const override {
2690     return isValidState() && MallocCalls.count(&CB);
2691   }
2692 
2693   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2694     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2695   }
2696 
2697   ChangeStatus manifest(Attributor &A) override {
2698     if (MallocCalls.empty())
2699       return ChangeStatus::UNCHANGED;
2700 
2701     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2702     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2703 
2704     Function *F = getAnchorScope();
2705     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2706                                             DepClassTy::OPTIONAL);
2707 
2708     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2709     for (CallBase *CB : MallocCalls) {
2710       // Skip replacing this if HeapToStack has already claimed it.
2711       if (HS && HS->isAssumedHeapToStack(*CB))
2712         continue;
2713 
2714       // Find the unique free call to remove it.
2715       SmallVector<CallBase *, 4> FreeCalls;
2716       for (auto *U : CB->users()) {
2717         CallBase *C = dyn_cast<CallBase>(U);
2718         if (C && C->getCalledFunction() == FreeCall.Declaration)
2719           FreeCalls.push_back(C);
2720       }
2721       if (FreeCalls.size() != 1)
2722         continue;
2723 
2724       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2725 
2726       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2727                         << " with " << AllocSize->getZExtValue()
2728                         << " bytes of shared memory\n");
2729 
2730       // Create a new shared memory buffer of the same size as the allocation
2731       // and replace all the uses of the original allocation with it.
2732       Module *M = CB->getModule();
2733       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2734       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2735       auto *SharedMem = new GlobalVariable(
2736           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2737           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2738           GlobalValue::NotThreadLocal,
2739           static_cast<unsigned>(AddressSpace::Shared));
2740       auto *NewBuffer =
2741           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2742 
2743       auto Remark = [&](OptimizationRemark OR) {
2744         return OR << "Replaced globalized variable with "
2745                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2746                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2747                   << "of shared memory.";
2748       };
2749       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2750 
2751       SharedMem->setAlignment(MaybeAlign(32));
2752 
2753       A.changeValueAfterManifest(*CB, *NewBuffer);
2754       A.deleteAfterManifest(*CB);
2755       A.deleteAfterManifest(*FreeCalls.front());
2756 
2757       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2758       Changed = ChangeStatus::CHANGED;
2759     }
2760 
2761     return Changed;
2762   }
2763 
2764   ChangeStatus updateImpl(Attributor &A) override {
2765     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2766     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2767     Function *F = getAnchorScope();
2768 
2769     auto NumMallocCalls = MallocCalls.size();
2770 
2771     // Only keep malloc calls by the initial thread with a constant size.
2772     for (User *U : RFI.Declaration->users()) {
2773       const auto &ED = A.getAAFor<AAExecutionDomain>(
2774           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2775       if (CallBase *CB = dyn_cast<CallBase>(U))
2776         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2777             !ED.isExecutedByInitialThreadOnly(*CB))
2778           MallocCalls.erase(CB);
2779     }
2780 
2781     findPotentialRemovedFreeCalls(A);
2782 
2783     if (NumMallocCalls != MallocCalls.size())
2784       return ChangeStatus::CHANGED;
2785 
2786     return ChangeStatus::UNCHANGED;
2787   }
2788 
2789   /// Collection of all malloc calls in a function.
2790   SmallPtrSet<CallBase *, 4> MallocCalls;
2791   /// Collection of potentially removed free calls in a function.
2792   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2793 };
2794 
2795 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2796   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2797   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2798 
2799   /// Statistics are tracked as part of manifest for now.
2800   void trackStatistics() const override {}
2801 
2802   /// See AbstractAttribute::getAsStr()
2803   const std::string getAsStr() const override {
2804     if (!isValidState())
2805       return "<invalid>";
2806     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2807                                                             : "generic") +
2808            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2809                                                                : "") +
2810            std::string(" #PRs: ") +
2811            std::to_string(ReachedKnownParallelRegions.size()) +
2812            ", #Unknown PRs: " +
2813            std::to_string(ReachedUnknownParallelRegions.size());
2814   }
2815 
2816   /// Create an abstract attribute view for the position \p IRP.
2817   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2818 
2819   /// See AbstractAttribute::getName()
2820   const std::string getName() const override { return "AAKernelInfo"; }
2821 
2822   /// See AbstractAttribute::getIdAddr()
2823   const char *getIdAddr() const override { return &ID; }
2824 
2825   /// This function should return true if the type of the \p AA is AAKernelInfo
2826   static bool classof(const AbstractAttribute *AA) {
2827     return (AA->getIdAddr() == &ID);
2828   }
2829 
2830   static const char ID;
2831 };
2832 
2833 /// The function kernel info abstract attribute, basically, what can we say
2834 /// about a function with regards to the KernelInfoState.
2835 struct AAKernelInfoFunction : AAKernelInfo {
2836   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2837       : AAKernelInfo(IRP, A) {}
2838 
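       /// Instructions that have already been assigned to a guarded region and
       /// therefore must not be guarded again.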
2839   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2840 
2841   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2842     return GuardedInstructions;
2843   }
2844 
2845   /// See AbstractAttribute::initialize(...).
2846   void initialize(Attributor &A) override {
2847     // This is a high-level transform that might change the constant arguments
2848     // of the init and deinit calls. We need to tell the Attributor about this
2849     // to avoid other parts using the current constant value for simplification.
2850     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2851 
2852     Function *Fn = getAnchorScope();
2853     if (!OMPInfoCache.Kernels.count(Fn))
2854       return;
2855 
2856     // Add itself to the reaching kernel and set IsKernelEntry.
2857     ReachingKernelEntries.insert(Fn);
2858     IsKernelEntry = true;
2859 
2860     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2861         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2862     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2863         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2864 
2865     // For kernels we perform more initialization work: first we find the init
2866     // and deinit calls.
2867     auto StoreCallBase = [](Use &U,
2868                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2869                             CallBase *&Storage) {
2870       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2871       assert(CB &&
2872              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2873       assert(!Storage &&
2874              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2875       Storage = CB;
2876       return false;
2877     };
2878     InitRFI.foreachUse(
2879         [&](Use &U, Function &) {
2880           StoreCallBase(U, InitRFI, KernelInitCB);
2881           return false;
2882         },
2883         Fn);
2884     DeinitRFI.foreachUse(
2885         [&](Use &U, Function &) {
2886           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2887           return false;
2888         },
2889         Fn);
2890 
2891     // Ignore kernels without initializers such as global constructors.
2892     if (!KernelInitCB || !KernelDeinitCB) {
2893       indicateOptimisticFixpoint();
2894       return;
2895     }
2896 
2897     // For kernels we might need to initialize/finalize the IsSPMD state and
2898     // we need to register a simplification callback so that the Attributor
2899     // knows the constant arguments to __kmpc_target_init and
2900     // __kmpc_target_deinit might actually change.
2901 
2902     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2903         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2904             bool &UsedAssumedInformation) -> Optional<Value *> {
2905       // IRP represents the "use generic state machine" argument of an
2906       // __kmpc_target_init call. We will answer this one with the internal
2907       // state. As long as we are not in an invalid state, we will create a
2908       // custom state machine so the value should be a `i1 false`. If we are
2909       // in an invalid state, we won't change the value that is in the IR.
2910       if (!isValidState())
2911         return nullptr;
2912       // If we have disabled state machine rewrites, don't make a custom one.
2913       if (DisableOpenMPOptStateMachineRewrite)
2914         return nullptr;
2915       if (AA)
2916         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2917       UsedAssumedInformation = !isAtFixpoint();
2918       auto *FalseVal =
2919           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2920       return FalseVal;
2921     };
2922 
2923     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2924         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2925             bool &UsedAssumedInformation) -> Optional<Value *> {
2926       // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
2927       // __kmpc_target_deinit call. We will answer this one with the internal
2928       // state of the SPMDCompatibilityTracker, that is, whether we assume the
2929       // kernel is SPMD compatible.
2930       if (!SPMDCompatibilityTracker.isValidState())
2931         return nullptr;
2932       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2933         if (AA)
2934           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2935         UsedAssumedInformation = true;
2936       } else {
2937         UsedAssumedInformation = false;
2938       }
2939       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2940                                        SPMDCompatibilityTracker.isAssumed());
2941       return Val;
2942     };
2943 
2944     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2945         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2946             bool &UsedAssumedInformation) -> Optional<Value *> {
2947       // IRP represents the "RequiresFullRuntime" argument of an
2948       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2949       // one with the internal state of the SPMDCompatibilityTracker, so if
2950       // generic then true, if SPMD then false.
2951       if (!SPMDCompatibilityTracker.isValidState())
2952         return nullptr;
2953       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2954         if (AA)
2955           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2956         UsedAssumedInformation = true;
2957       } else {
2958         UsedAssumedInformation = false;
2959       }
2960       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2961                                        !SPMDCompatibilityTracker.isAssumed());
2962       return Val;
2963     };
2964 
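         // Argument positions of the mode, state machine, and full-runtime flags
         // in __kmpc_target_init and __kmpc_target_deinit calls.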
2965     constexpr const int InitIsSPMDArgNo = 1;
2966     constexpr const int DeinitIsSPMDArgNo = 1;
2967     constexpr const int InitUseStateMachineArgNo = 2;
2968     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2969     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2970     A.registerSimplificationCallback(
2971         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2972         StateMachineSimplifyCB);
2973     A.registerSimplificationCallback(
2974         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2975         IsSPMDModeSimplifyCB);
2976     A.registerSimplificationCallback(
2977         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2978         IsSPMDModeSimplifyCB);
2979     A.registerSimplificationCallback(
2980         IRPosition::callsite_argument(*KernelInitCB,
2981                                       InitRequiresFullRuntimeArgNo),
2982         IsGenericModeSimplifyCB);
2983     A.registerSimplificationCallback(
2984         IRPosition::callsite_argument(*KernelDeinitCB,
2985                                       DeinitRequiresFullRuntimeArgNo),
2986         IsGenericModeSimplifyCB);
2987 
2988     // Check if we know we are in SPMD-mode already.
2989     ConstantInt *IsSPMDArg =
2990         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2991     if (IsSPMDArg && !IsSPMDArg->isZero())
2992       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
2993     // This is a generic region but SPMDization is disabled so stop tracking.
2994     else if (DisableOpenMPOptSPMDization)
2995       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
2996   }
2997 
2998   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
2999   /// finished now.
3000   ChangeStatus manifest(Attributor &A) override {
3001     // If we are not looking at a kernel with __kmpc_target_init and
3002     // __kmpc_target_deinit call we cannot actually manifest the information.
3003     if (!KernelInitCB || !KernelDeinitCB)
3004       return ChangeStatus::UNCHANGED;
3005 
3006     // Known SPMD-mode kernels need no manifest changes.
3007     if (SPMDCompatibilityTracker.isKnown())
3008       return ChangeStatus::UNCHANGED;
3009 
3010     // If we can, we change the execution mode to SPMD-mode; otherwise we build
3011     // a custom state machine.
3012     if (!mayContainParallelRegion() || !changeToSPMDMode(A))
3013       buildCustomStateMachine(A);
3014 
3015     return ChangeStatus::CHANGED;
3016   }
3017 
3018   bool changeToSPMDMode(Attributor &A) {
3019     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3020 
3021     if (!SPMDCompatibilityTracker.isAssumed()) {
3022       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3023         if (!NonCompatibleI)
3024           continue;
3025 
3026         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3027         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3028           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3029             continue;
3030 
3031         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3032           ORA << "Value has potential side effects preventing SPMD-mode "
3033                  "execution";
3034           if (isa<CallBase>(NonCompatibleI)) {
3035             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3036                    "the called function to override";
3037           }
3038           return ORA << ".";
3039         };
3040         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3041                                                  Remark);
3042 
3043         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3044                           << *NonCompatibleI << "\n");
3045       }
3046 
3047       return false;
3048     }
3049 
3050     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3051                                    Instruction *RegionEndI) {
3052       LoopInfo *LI = nullptr;
3053       DominatorTree *DT = nullptr;
3054       MemorySSAUpdater *MSU = nullptr;
3055       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3056 
3057       BasicBlock *ParentBB = RegionStartI->getParent();
3058       Function *Fn = ParentBB->getParent();
3059       Module &M = *Fn->getParent();
3060 
3061       // Create all the blocks and logic.
3062       // ParentBB:
3063       //    goto RegionCheckTidBB
3064       // RegionCheckTidBB:
3065       //    Tid = __kmpc_get_hardware_thread_id_in_block()
3066       //    if (Tid != 0)
3067       //        goto RegionBarrierBB
3068       // RegionStartBB:
3069       //    <execute instructions guarded>
3070       //    goto RegionEndBB
3071       // RegionEndBB:
3072       //    <store escaping values to shared mem>
3073       //    goto RegionBarrierBB
3074       // RegionBarrierBB:
3075       //    __kmpc_barrier_simple_spmd()
3076       //    <load escaping values from shared mem>
3077       //    // second barrier is omitted if there are no escaping values.
3078       //    __kmpc_barrier_simple_spmd()
3079       //    goto RegionExitBB
3080       // RegionExitBB:
3081       //    <execute rest of instructions>
3082 
3083       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3084                                            DT, LI, MSU, "region.guarded.end");
3085       BasicBlock *RegionBarrierBB =
3086           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3087                      MSU, "region.barrier");
3088       BasicBlock *RegionExitBB =
3089           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3090                      DT, LI, MSU, "region.exit");
3091       BasicBlock *RegionStartBB =
3092           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3093 
3094       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3095              "Expected a different CFG");
3096 
3097       BasicBlock *RegionCheckTidBB = SplitBlock(
3098           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3099 
3100       // Register basic blocks with the Attributor.
3101       A.registerManifestAddedBasicBlock(*RegionEndBB);
3102       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3103       A.registerManifestAddedBasicBlock(*RegionExitBB);
3104       A.registerManifestAddedBasicBlock(*RegionStartBB);
3105       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3106 
3107       bool HasBroadcastValues = false;
3108       // Find escaping outputs from the guarded region to outside users and
3109       // broadcast their values to them.
3110       for (Instruction &I : *RegionStartBB) {
3111         SmallPtrSet<Instruction *, 4> OutsideUsers;
3112         for (User *Usr : I.users()) {
3113           Instruction &UsrI = *cast<Instruction>(Usr);
3114           if (UsrI.getParent() != RegionStartBB)
3115             OutsideUsers.insert(&UsrI);
3116         }
3117 
3118         if (OutsideUsers.empty())
3119           continue;
3120 
3121         HasBroadcastValues = true;
3122 
3123         // Emit a global variable in shared memory to store the broadcasted
3124         // value.
3125         auto *SharedMem = new GlobalVariable(
3126             M, I.getType(), /* IsConstant */ false,
3127             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3128             I.getName() + ".guarded.output.alloc", nullptr,
3129             GlobalValue::NotThreadLocal,
3130             static_cast<unsigned>(AddressSpace::Shared));
3131 
3132         // Emit a store instruction to update the value.
3133         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3134 
3135         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3136                                        I.getName() + ".guarded.output.load",
3137                                        RegionBarrierBB->getTerminator());
3138 
3139         // Replace the escaping uses of the output value with the loaded value.
3140         for (Instruction *UsrI : OutsideUsers) {
3141           assert(UsrI->getParent() == RegionExitBB &&
3142                  "Expected escaping users in exit region");
3143           UsrI->replaceUsesOfWith(&I, LoadI);
3144         }
3145       }
3146 
3147       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3148 
3149       // Go to tid check BB in ParentBB.
3150       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3151       ParentBB->getTerminator()->eraseFromParent();
3152       OpenMPIRBuilder::LocationDescription Loc(
3153           InsertPointTy(ParentBB, ParentBB->end()), DL);
3154       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3155       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3156       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3157       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3158 
3159       // Add check for Tid in RegionCheckTidBB
3160       RegionCheckTidBB->getTerminator()->eraseFromParent();
3161       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3162           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3163       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3164       FunctionCallee HardwareTidFn =
3165           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3166               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3167       Value *Tid =
3168           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3169       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3170       OMPInfoCache.OMPBuilder.Builder
3171           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3172           ->setDebugLoc(DL);
3173 
3174       // First barrier for synchronization, ensures main thread has updated
3175       // values.
3176       FunctionCallee BarrierFn =
3177           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3178               M, OMPRTL___kmpc_barrier_simple_spmd);
3179       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3180           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3181       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3182           ->setDebugLoc(DL);
3183 
3184       // Second barrier ensures workers have read broadcast values.
3185       if (HasBroadcastValues)
3186         CallInst::Create(BarrierFn, {Ident, Tid}, "",
3187                          RegionBarrierBB->getTerminator())
3188             ->setDebugLoc(DL);
3189     };
3190 
3191     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3192 
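         // Collect maximal runs of consecutive to-be-guarded instructions per
         // basic block; each run becomes a single guarded region executed only by
         // the main thread.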
3193     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3194       BasicBlock *BB = GuardedI->getParent();
3195       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3196           IRPosition::function(*GuardedI->getFunction()), nullptr,
3197           DepClassTy::NONE);
3198       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3199       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3200       // Continue if instruction is already guarded.
3201       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3202         continue;
3203 
3204       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3205       for (Instruction &I : *BB) {
3206         // If instruction I needs to be guarded, update the guarded region
3207         // bounds.
3208         if (SPMDCompatibilityTracker.contains(&I)) {
3209           CalleeAAFunction.getGuardedInstructions().insert(&I);
3210           if (GuardedRegionStart)
3211             GuardedRegionEnd = &I;
3212           else
3213             GuardedRegionStart = GuardedRegionEnd = &I;
3214 
3215           continue;
3216         }
3217 
3218         // Instruction I does not need guarding; store
3219         // any region found so far and reset the bounds.
3220         if (GuardedRegionStart) {
3221           GuardedRegions.push_back(
3222               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3223           GuardedRegionStart = nullptr;
3224           GuardedRegionEnd = nullptr;
3225         }
3226       }
3227     }
3228 
3229     for (auto &GR : GuardedRegions)
3230       CreateGuardedRegion(GR.first, GR.second);
3231 
3232     // Adjust the global exec mode flag that tells the runtime what mode this
3233     // kernel is executed in.
3234     Function *Kernel = getAnchorScope();
3235     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3236         (Kernel->getName() + "_exec_mode").str());
3237     assert(ExecMode && "Kernel without exec mode?");
3238     assert(ExecMode->getInitializer() &&
3239            ExecMode->getInitializer()->isOneValue() &&
3240            "Initially non-SPMD kernel has SPMD exec mode!");
3241 
3242     // Set the global exec mode flag to indicate SPMD-Generic mode.
3243     constexpr int SPMDGeneric = 2;
3244     if (!ExecMode->getInitializer()->isZeroValue())
3245       ExecMode->setInitializer(
3246           ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
3247 
3248     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3249     const int InitIsSPMDArgNo = 1;
3250     const int DeinitIsSPMDArgNo = 1;
3251     const int InitUseStateMachineArgNo = 2;
3252     const int InitRequiresFullRuntimeArgNo = 3;
3253     const int DeinitRequiresFullRuntimeArgNo = 2;
3254 
3255     auto &Ctx = getAnchorValue().getContext();
3256     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3257                              *ConstantInt::getBool(Ctx, 1));
3258     A.changeUseAfterManifest(
3259         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3260         *ConstantInt::getBool(Ctx, 0));
3261     A.changeUseAfterManifest(
3262         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3263         *ConstantInt::getBool(Ctx, 1));
3264     A.changeUseAfterManifest(
3265         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3266         *ConstantInt::getBool(Ctx, 0));
3267     A.changeUseAfterManifest(
3268         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3269         *ConstantInt::getBool(Ctx, 0));
3270 
3271     ++NumOpenMPTargetRegionKernelsSPMD;
3272 
3273     auto Remark = [&](OptimizationRemark OR) {
3274       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3275     };
3276     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3277     return true;
3278   };
3279 
3280   ChangeStatus buildCustomStateMachine(Attributor &A) {
3281     // If we have disabled state machine rewrites, don't make a custom one
3282     if (DisableOpenMPOptStateMachineRewrite)
3283       return indicatePessimisticFixpoint();
3284 
3285     assert(ReachedKnownParallelRegions.isValidState() &&
3286            "Custom state machine with invalid parallel region states?");
3287 
3288     const int InitIsSPMDArgNo = 1;
3289     const int InitUseStateMachineArgNo = 2;
3290 
3291     // Check that the current configuration is non-SPMD and uses the generic
3292     // state machine. If we already have SPMD mode or a custom state machine we
3293     // do not need to go any further. If either argument is anything but a
3294     // constant, something is weird and we give up.
3295     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3296         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3297     ConstantInt *IsSPMD =
3298         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3299 
3300     // If we are stuck with generic mode, try to create a custom device (=GPU)
3301     // state machine which is specialized for the parallel regions that are
3302     // reachable by the kernel.
3303     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3304         !IsSPMD->isZero())
3305       return ChangeStatus::UNCHANGED;
3306 
3307     // If not SPMD mode, indicate we use a custom state machine now.
3308     auto &Ctx = getAnchorValue().getContext();
3309     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3310     A.changeUseAfterManifest(
3311         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3312 
3313     // If we don't actually need a state machine we are done here. This can
3314     // happen if there simply are no parallel regions. In the resulting kernel
3315     // all worker threads will simply exit right away, leaving the main thread
3316     // to do the work alone.
3317     if (!mayContainParallelRegion()) {
3318       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3319 
3320       auto Remark = [&](OptimizationRemark OR) {
3321         return OR << "Removing unused state machine from generic-mode kernel.";
3322       };
3323       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3324 
3325       return ChangeStatus::CHANGED;
3326     }
3327 
3328     // Keep track in the statistics of our new shiny custom state machine.
3329     if (ReachedUnknownParallelRegions.empty()) {
3330       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3331 
3332       auto Remark = [&](OptimizationRemark OR) {
3333         return OR << "Rewriting generic-mode kernel with a customized state "
3334                      "machine.";
3335       };
3336       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3337     } else {
3338       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3339 
3340       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3341         return OR << "Generic-mode kernel is executed with a customized state "
3342                      "machine that requires a fallback.";
3343       };
3344       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3345 
3346       // Tell the user why we ended up with a fallback.
3347       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3348         if (!UnknownParallelRegionCB)
3349           continue;
3350         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3351           return ORA << "Call may contain unknown parallel regions. Use "
3352                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3353                         "override.";
3354         };
3355         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3356                                                  "OMP133", Remark);
3357       }
3358     }
3359 
3360     // Create all the blocks:
3361     //
3362     //                       InitCB = __kmpc_target_init(...)
3363     //                       bool IsWorker = InitCB >= 0;
3364     //                       if (IsWorker) {
3365     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3366     //                         void *WorkFn;
3367     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3368     //                         if (!WorkFn) return;
3369     // SMIsActiveCheckBB:       if (Active) {
3370     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3371     //                              ParFn0(...);
3372     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3373     //                              ParFn1(...);
3374     //                            ...
3375     // SMIfCascadeCurrentBB:      else
3376     //                              ((WorkFnTy*)WorkFn)(...);
3377     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3378     //                          }
3379     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3380     //                          goto SMBeginBB;
3381     //                       }
3382     // UserCodeEntryBB:      // user code
3383     //                       __kmpc_target_deinit(...)
3384     //
3385     Function *Kernel = getAssociatedFunction();
3386     assert(Kernel && "Expected an associated function!");
3387 
3388     BasicBlock *InitBB = KernelInitCB->getParent();
3389     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3390         KernelInitCB->getNextNode(), "thread.user_code.check");
3391     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3392         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3393     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3394         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3395     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3396         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3397     BasicBlock *StateMachineIfCascadeCurrentBB =
3398         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3399                            Kernel, UserCodeEntryBB);
3400     BasicBlock *StateMachineEndParallelBB =
3401         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3402                            Kernel, UserCodeEntryBB);
3403     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3404         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3405     A.registerManifestAddedBasicBlock(*InitBB);
3406     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3407     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3408     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3409     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3410     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3411     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3412     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3413 
3414     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3415     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3416 
3417     InitBB->getTerminator()->eraseFromParent();
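         // The return value of __kmpc_target_init is -1 for the main thread and a
         // non-negative thread id for workers (see the sketch above), so comparing
         // against -1 identifies the worker threads.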
3418     Instruction *IsWorker =
3419         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3420                          ConstantInt::get(KernelInitCB->getType(), -1),
3421                          "thread.is_worker", InitBB);
3422     IsWorker->setDebugLoc(DLoc);
3423     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3424 
3425     // Create local storage for the work function pointer.
3426     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3427     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3428                                           &Kernel->getEntryBlock().front());
3429     WorkFnAI->setDebugLoc(DLoc);
3430 
3431     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3432     OMPInfoCache.OMPBuilder.updateToLocation(
3433         OpenMPIRBuilder::LocationDescription(
3434             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3435                                      StateMachineBeginBB->end()),
3436             DLoc));
3437 
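         // Reuse the ident of the init call; its return value also serves as the
         // thread id argument for the runtime calls emitted below.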
3438     Value *Ident = KernelInitCB->getArgOperand(0);
3439     Value *GTid = KernelInitCB;
3440 
3441     Module &M = *Kernel->getParent();
3442     FunctionCallee BarrierFn =
3443         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3444             M, OMPRTL___kmpc_barrier_simple_spmd);
3445     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3446         ->setDebugLoc(DLoc);
3447 
3448     FunctionCallee KernelParallelFn =
3449         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3450             M, OMPRTL___kmpc_kernel_parallel);
3451     Instruction *IsActiveWorker = CallInst::Create(
3452         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3453     IsActiveWorker->setDebugLoc(DLoc);
3454     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3455                                        StateMachineBeginBB);
3456     WorkFn->setDebugLoc(DLoc);
3457 
3458     FunctionType *ParallelRegionFnTy = FunctionType::get(
3459         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3460         false);
3461     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3462         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3463         StateMachineBeginBB);
3464 
3465     Instruction *IsDone =
3466         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3467                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3468                          StateMachineBeginBB);
3469     IsDone->setDebugLoc(DLoc);
3470     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3471                        IsDone, StateMachineBeginBB)
3472         ->setDebugLoc(DLoc);
3473 
3474     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3475                        StateMachineDoneBarrierBB, IsActiveWorker,
3476                        StateMachineIsActiveCheckBB)
3477         ->setDebugLoc(DLoc);
3478 
3479     Value *ZeroArg =
3480         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3481 
3482     // Now that we have most of the CFG skeleton it is time for the if-cascade
3483     // that checks the function pointer we got from the runtime against the
3484     // parallel regions we expect, if there are any.
3485     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3486       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3487       BasicBlock *PRExecuteBB = BasicBlock::Create(
3488           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3489           StateMachineEndParallelBB);
3490       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3491           ->setDebugLoc(DLoc);
3492       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3493           ->setDebugLoc(DLoc);
3494 
3495       BasicBlock *PRNextBB =
3496           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3497                              Kernel, StateMachineEndParallelBB);
3498 
3499       // Check if we need to compare the pointer at all or if we can just
3500       // call the parallel region function.
3501       Value *IsPR;
3502       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3503         Instruction *CmpI = ICmpInst::Create(
3504             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3505             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3506         CmpI->setDebugLoc(DLoc);
3507         IsPR = CmpI;
3508       } else {
3509         IsPR = ConstantInt::getTrue(Ctx);
3510       }
3511 
3512       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3513                          StateMachineIfCascadeCurrentBB)
3514           ->setDebugLoc(DLoc);
3515       StateMachineIfCascadeCurrentBB = PRNextBB;
3516     }
3517 
3518     // At the end of the if-cascade we place the indirect function pointer call
3519     // in case we might need it, that is if there can be parallel regions we
3520     // have not handled in the if-cascade above.
3521     if (!ReachedUnknownParallelRegions.empty()) {
3522       StateMachineIfCascadeCurrentBB->setName(
3523           "worker_state_machine.parallel_region.fallback.execute");
3524       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3525                        StateMachineIfCascadeCurrentBB)
3526           ->setDebugLoc(DLoc);
3527     }
3528     BranchInst::Create(StateMachineEndParallelBB,
3529                        StateMachineIfCascadeCurrentBB)
3530         ->setDebugLoc(DLoc);
3531 
3532     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3533                          M, OMPRTL___kmpc_kernel_end_parallel),
3534                      {}, "", StateMachineEndParallelBB)
3535         ->setDebugLoc(DLoc);
3536     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3537         ->setDebugLoc(DLoc);
3538 
3539     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3540         ->setDebugLoc(DLoc);
3541     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3542         ->setDebugLoc(DLoc);
3543 
3544     return ChangeStatus::CHANGED;
3545   }
3546 
3547   /// Fixpoint iteration update function. Will be called every time a dependence
3548   /// changes its state (and in the beginning).
3549   ChangeStatus updateImpl(Attributor &A) override {
3550     KernelInfoState StateBefore = getState();
3551 
3552     // Callback to check a read/write instruction.
3553     auto CheckRWInst = [&](Instruction &I) {
3554       // We handle calls later.
3555       if (isa<CallBase>(I))
3556         return true;
3557       // We only care about write effects.
3558       if (!I.mayWriteToMemory())
3559         return true;
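           // Stores whose underlying objects are all stack allocations are
           // thread-private and need no guarding.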
3560       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3561         SmallVector<const Value *> Objects;
3562         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3563         if (llvm::all_of(Objects,
3564                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3565           return true;
3566         // Check for AAHeapToStack moved objects which must not be guarded.
3567         auto &HS = A.getAAFor<AAHeapToStack>(
3568             *this, IRPosition::function(*I.getFunction()),
3569             DepClassTy::REQUIRED);
3570         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3571               auto *CB = dyn_cast<CallBase>(Obj);
3572               if (!CB)
3573                 return false;
3574               return HS.isAssumedHeapToStack(*CB);
3575             })) {
3576           return true;
3577         }
3578       }
3579 
3580       // Insert instruction that needs guarding.
3581       SPMDCompatibilityTracker.insert(&I);
3582       return true;
3583     };
3584 
3585     bool UsedAssumedInformationInCheckRWInst = false;
3586     if (!SPMDCompatibilityTracker.isAtFixpoint())
3587       if (!A.checkForAllReadWriteInstructions(
3588               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3589         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3590 
3591     if (!IsKernelEntry) {
3592       updateReachingKernelEntries(A);
3593       updateParallelLevels(A);
3594 
3595       if (!ParallelLevels.isValidState())
3596         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3597     }
3598 
3599     // Callback to check a call instruction.
3600     bool AllSPMDStatesWereFixed = true;
3601     auto CheckCallInst = [&](Instruction &I) {
3602       auto &CB = cast<CallBase>(I);
3603       auto &CBAA = A.getAAFor<AAKernelInfo>(
3604           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3605       getState() ^= CBAA.getState();
3606       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3607       return true;
3608     };
3609 
3610     bool UsedAssumedInformationInCheckCallInst = false;
3611     if (!A.checkForAllCallLikeInstructions(
3612             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3613       return indicatePessimisticFixpoint();
3614 
3615     // If we haven't used any assumed information for the SPMD state we can fix
3616     // it.
3617     if (!UsedAssumedInformationInCheckRWInst &&
3618         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3619       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3620 
3621     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3622                                      : ChangeStatus::CHANGED;
3623   }
3624 
3625 private:
3626   /// Update info regarding reaching kernels.
3627   void updateReachingKernelEntries(Attributor &A) {
3628     auto PredCallSite = [&](AbstractCallSite ACS) {
3629       Function *Caller = ACS.getInstruction()->getFunction();
3630 
3631       assert(Caller && "Caller is nullptr");
3632 
3633       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3634           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3635       if (CAA.ReachingKernelEntries.isValidState()) {
3636         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3637         return true;
3638       }
3639 
3640       // We lost track of the callers of the associated function; any kernel
3641       // could reach it now.
3642       ReachingKernelEntries.indicatePessimisticFixpoint();
3643 
3644       return true;
3645     };
3646 
3647     bool AllCallSitesKnown;
3648     if (!A.checkForAllCallSites(PredCallSite, *this,
3649                                 true /* RequireAllCallSites */,
3650                                 AllCallSitesKnown))
3651       ReachingKernelEntries.indicatePessimisticFixpoint();
3652   }
3653 
3654   /// Update info regarding parallel levels.
3655   void updateParallelLevels(Attributor &A) {
3656     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3657     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3658         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3659 
3660     auto PredCallSite = [&](AbstractCallSite ACS) {
3661       Function *Caller = ACS.getInstruction()->getFunction();
3662 
3663       assert(Caller && "Caller is nullptr");
3664 
3665       auto &CAA =
3666           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3667       if (CAA.ParallelLevels.isValidState()) {
3668         // Any function that is called by `__kmpc_parallel_51` will not be
3669         // folded because the parallel level in the callee is updated at
3670         // runtime. Getting this right would tie the analysis to the runtime
3671         // implementation, and any future change there could silently make the
3672         // analysis wrong. As a consequence, we are just conservative here.
3673         if (Caller == Parallel51RFI.Declaration) {
3674           ParallelLevels.indicatePessimisticFixpoint();
3675           return true;
3676         }
3677 
3678         ParallelLevels ^= CAA.ParallelLevels;
3679 
3680         return true;
3681       }
3682 
3683       // We lost track of the callers of the associated function; any kernel
3684       // could reach it now.
3685       ParallelLevels.indicatePessimisticFixpoint();
3686 
3687       return true;
3688     };
3689 
3690     bool AllCallSitesKnown = true;
3691     if (!A.checkForAllCallSites(PredCallSite, *this,
3692                                 true /* RequireAllCallSites */,
3693                                 AllCallSitesKnown))
3694       ParallelLevels.indicatePessimisticFixpoint();
3695   }
3696 };
3697 
3698 /// The call site kernel info abstract attribute, basically, what can we say
3699 /// about a call site with regards to the KernelInfoState. For now this simply
3700 /// forwards the information from the callee.
3701 struct AAKernelInfoCallSite : AAKernelInfo {
3702   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3703       : AAKernelInfo(IRP, A) {}
3704 
3705   /// See AbstractAttribute::initialize(...).
3706   void initialize(Attributor &A) override {
3707     AAKernelInfo::initialize(A);
3708 
3709     CallBase &CB = cast<CallBase>(getAssociatedValue());
3710     Function *Callee = getAssociatedFunction();
3711 
3712     // Helper to lookup an assumption string.
3713     auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
3714       return Fn && hasAssumption(*Fn, AssumptionStr);
3715     };
3716 
3717     // Check for SPMD-mode assumptions.
3718     if (HasAssumption(Callee, "ompx_spmd_amenable"))
3719       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3720 
3721     // First weed out calls we do not care about, that is readonly/readnone
3722     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3723     // parallel region or anything else we are looking for.
3724     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3725       indicateOptimisticFixpoint();
3726       return;
3727     }
3728 
3729     // Next we check if we know the callee. If it is a known OpenMP function
3730     // we will handle them explicitly in the switch below. If it is not, we
3731     // will use an AAKernelInfo object on the callee to gather information and
3732     // merge that into the current state. The latter happens in the updateImpl.
3733     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3734     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3735     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3736       // Unknown or non-IPO-amendable callees are not analyzable, we give up.
3737       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3738 
3739         // Unknown callees might contain parallel regions, except if they have
3740         // an appropriate assumption attached.
3741         if (!(HasAssumption(Callee, "omp_no_openmp") ||
3742               HasAssumption(Callee, "omp_no_parallelism")))
3743           ReachedUnknownParallelRegions.insert(&CB);
3744 
3745         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3746         // idea we can run something unknown in SPMD-mode.
3747         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3748           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3749           SPMDCompatibilityTracker.insert(&CB);
3750         }
3751 
3752         // We have updated the state for this unknown call properly, there won't
3753         // be any change so we indicate a fixpoint.
3754         indicateOptimisticFixpoint();
3755       }
3756       // If the callee is known and can be used in IPO, we will update the state
3757       // based on the callee state in updateImpl.
3758       return;
3759     }
3760 
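         // __kmpc_parallel_51 takes the outlined parallel region through its
         // wrapper function operand.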
3761     const unsigned int WrapperFunctionArgNo = 6;
3762     RuntimeFunction RF = It->getSecond();
3763     switch (RF) {
3764     // All the functions we know are compatible with SPMD mode.
3765     case OMPRTL___kmpc_is_spmd_exec_mode:
3766     case OMPRTL___kmpc_for_static_fini:
3767     case OMPRTL___kmpc_global_thread_num:
3768     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3769     case OMPRTL___kmpc_get_hardware_num_blocks:
3770     case OMPRTL___kmpc_single:
3771     case OMPRTL___kmpc_end_single:
3772     case OMPRTL___kmpc_master:
3773     case OMPRTL___kmpc_end_master:
3774     case OMPRTL___kmpc_barrier:
3775       break;
3776     case OMPRTL___kmpc_for_static_init_4:
3777     case OMPRTL___kmpc_for_static_init_4u:
3778     case OMPRTL___kmpc_for_static_init_8:
3779     case OMPRTL___kmpc_for_static_init_8u: {
3780       // Check the schedule and allow static schedule in SPMD mode.
3781       unsigned ScheduleArgOpNo = 2;
3782       auto *ScheduleTypeCI =
3783           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3784       unsigned ScheduleTypeVal =
3785           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3786       switch (OMPScheduleType(ScheduleTypeVal)) {
3787       case OMPScheduleType::Static:
3788       case OMPScheduleType::StaticChunked:
3789       case OMPScheduleType::Distribute:
3790       case OMPScheduleType::DistributeChunked:
3791         break;
3792       default:
3793         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3794         SPMDCompatibilityTracker.insert(&CB);
3795         break;
3796       };
3797     } break;
3798     case OMPRTL___kmpc_target_init:
3799       KernelInitCB = &CB;
3800       break;
3801     case OMPRTL___kmpc_target_deinit:
3802       KernelDeinitCB = &CB;
3803       break;
3804     case OMPRTL___kmpc_parallel_51:
3805       if (auto *ParallelRegion = dyn_cast<Function>(
3806               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3807         ReachedKnownParallelRegions.insert(ParallelRegion);
3808         break;
3809       }
3810       // The condition above should usually get the parallel region function
3811       // pointer and record it. On the off chance it doesn't, we assume the
3812       // worst.
3813       ReachedUnknownParallelRegions.insert(&CB);
3814       break;
3815     case OMPRTL___kmpc_omp_task:
3816       // We do not look into tasks right now, just give up.
3817       SPMDCompatibilityTracker.insert(&CB);
3818       ReachedUnknownParallelRegions.insert(&CB);
3819       indicatePessimisticFixpoint();
3820       return;
3821     case OMPRTL___kmpc_alloc_shared:
3822     case OMPRTL___kmpc_free_shared:
3823       // Return without setting a fixpoint, to be resolved in updateImpl.
3824       return;
3825     default:
3826       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3827       // generally.
3828       SPMDCompatibilityTracker.insert(&CB);
3829       indicatePessimisticFixpoint();
3830       return;
3831     }
3832     // All other OpenMP runtime calls will not reach parallel regions so they
3833     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3834     // have now modeled all effects and there is no need for any update.
3835     indicateOptimisticFixpoint();
3836   }
3837 
3838   ChangeStatus updateImpl(Attributor &A) override {
3839     // TODO: Once we have call site specific value information we can provide
3840     //       call site specific liveness information and then it makes
3841     //       sense to specialize attributes for call sites arguments instead of
3842     //       redirecting requests to the callee argument.
3843     Function *F = getAssociatedFunction();
3844 
3845     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3846     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3847 
3848     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3849     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3850       const IRPosition &FnPos = IRPosition::function(*F);
3851       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3852       if (getState() == FnAA.getState())
3853         return ChangeStatus::UNCHANGED;
3854       getState() = FnAA.getState();
3855       return ChangeStatus::CHANGED;
3856     }
3857 
3858     // F is a runtime function that allocates or frees memory, check
3859     // AAHeapToStack and AAHeapToShared.
3860     KernelInfoState StateBefore = getState();
3861     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3862             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3863            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3864 
3865     CallBase &CB = cast<CallBase>(getAssociatedValue());
3866 
3867     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3868         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3869     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3870         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3871 
3872     RuntimeFunction RF = It->getSecond();
3873 
3874     switch (RF) {
3875     // If neither HeapToStack nor HeapToShared assume the call is replaced or
3876     // removed, assume SPMD incompatibility.
3877     case OMPRTL___kmpc_alloc_shared:
3878       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3879           !HeapToSharedAA.isAssumedHeapToShared(CB))
3880         SPMDCompatibilityTracker.insert(&CB);
3881       break;
3882     case OMPRTL___kmpc_free_shared:
3883       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3884           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3885         SPMDCompatibilityTracker.insert(&CB);
3886       break;
3887     default:
3888       SPMDCompatibilityTracker.insert(&CB);
3889     }
3890 
3891     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3892                                      : ChangeStatus::CHANGED;
3893   }
3894 };
3895 
3896 struct AAFoldRuntimeCall
3897     : public StateWrapper<BooleanState, AbstractAttribute> {
3898   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3899 
3900   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3901 
3902   /// Statistics are tracked as part of manifest for now.
3903   void trackStatistics() const override {}
3904 
3905   /// Create an abstract attribute view for the position \p IRP.
3906   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3907                                               Attributor &A);
3908 
3909   /// See AbstractAttribute::getName()
3910   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3911 
3912   /// See AbstractAttribute::getIdAddr()
3913   const char *getIdAddr() const override { return &ID; }
3914 
3915   /// This function should return true if the type of the \p AA is
3916   /// AAFoldRuntimeCall
3917   static bool classof(const AbstractAttribute *AA) {
3918     return (AA->getIdAddr() == &ID);
3919   }
3920 
3921   static const char ID;
3922 };
3923 
3924 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3925   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3926       : AAFoldRuntimeCall(IRP, A) {}
3927 
3928   /// See AbstractAttribute::getAsStr()
3929   const std::string getAsStr() const override {
3930     if (!isValidState())
3931       return "<invalid>";
3932 
3933     std::string Str("simplified value: ");
3934 
3935     if (!SimplifiedValue.hasValue())
3936       return Str + std::string("none");
3937 
3938     if (!SimplifiedValue.getValue())
3939       return Str + std::string("nullptr");
3940 
3941     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
3942       return Str + std::to_string(CI->getSExtValue());
3943 
3944     return Str + std::string("unknown");
3945   }
3946 
3947   void initialize(Attributor &A) override {
3948     if (DisableOpenMPOptFolding)
3949       indicatePessimisticFixpoint();
3950 
3951     Function *Callee = getAssociatedFunction();
3952 
3953     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3954     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3955     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
3956            "Expected a known OpenMP runtime function");
3957 
3958     RFKind = It->getSecond();
3959 
3960     CallBase &CB = cast<CallBase>(getAssociatedValue());
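         // Let other AAs query the (assumed) simplified return value of this call
         // during the fixpoint iteration.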
3961     A.registerSimplificationCallback(
3962         IRPosition::callsite_returned(CB),
3963         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3964             bool &UsedAssumedInformation) -> Optional<Value *> {
3965           assert((isValidState() || (SimplifiedValue.hasValue() &&
3966                                      SimplifiedValue.getValue() == nullptr)) &&
3967                  "Unexpected invalid state!");
3968 
3969           if (!isAtFixpoint()) {
3970             UsedAssumedInformation = true;
3971             if (AA)
3972               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3973           }
3974           return SimplifiedValue;
3975         });
3976   }
3977 
3978   ChangeStatus updateImpl(Attributor &A) override {
3979     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3980     switch (RFKind) {
3981     case OMPRTL___kmpc_is_spmd_exec_mode:
3982       Changed |= foldIsSPMDExecMode(A);
3983       break;
3984     case OMPRTL___kmpc_is_generic_main_thread_id:
3985       Changed |= foldIsGenericMainThread(A);
3986       break;
3987     case OMPRTL___kmpc_parallel_level:
3988       Changed |= foldParallelLevel(A);
3989       break;
3990     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
3992       break;
3993     case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
3995       break;
3996     default:
3997       llvm_unreachable("Unhandled OpenMP runtime function!");
3998     }
3999 
4000     return Changed;
4001   }
4002 
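  /// See AbstractAttribute::manifest(...). If a non-null simplified value is
  /// known, replace all uses of the call with it and delete the call.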
4003   ChangeStatus manifest(Attributor &A) override {
4004     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4005 
4006     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4007       Instruction &CB = *getCtxI();
4008       A.changeValueAfterManifest(CB, **SimplifiedValue);
4009       A.deleteAfterManifest(CB);
4010 
4011       LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with "
4012                         << **SimplifiedValue << "\n");
4013 
4014       Changed = ChangeStatus::CHANGED;
4015     }
4016 
4017     return Changed;
4018   }
4019 
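  /// A pessimistic fixpoint means the call cannot be folded; record that by
  /// setting the simplified value to nullptr.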
4020   ChangeStatus indicatePessimisticFixpoint() override {
4021     SimplifiedValue = nullptr;
4022     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4023   }
4024 
4025 private:
4026   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4027   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4028     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4029 
4030     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4031     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4032     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4033         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4034 
4035     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4036       return indicatePessimisticFixpoint();
4037 
4038     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4039       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4040                                           DepClassTy::REQUIRED);
4041 
4042       if (!AA.isValidState()) {
4043         SimplifiedValue = nullptr;
4044         return indicatePessimisticFixpoint();
4045       }
4046 
4047       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4048         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4049           ++KnownSPMDCount;
4050         else
4051           ++AssumedSPMDCount;
4052       } else {
4053         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4054           ++KnownNonSPMDCount;
4055         else
4056           ++AssumedNonSPMDCount;
4057       }
4058     }
4059 
4060     if ((AssumedSPMDCount + KnownSPMDCount) &&
4061         (AssumedNonSPMDCount + KnownNonSPMDCount))
4062       return indicatePessimisticFixpoint();
4063 
4064     auto &Ctx = getAnchorValue().getContext();
4065     if (KnownSPMDCount || AssumedSPMDCount) {
4066       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4067              "Expected only SPMD kernels!");
4068       // All reaching kernels are in SPMD mode. Update all function calls to
4069       // __kmpc_is_spmd_exec_mode to 1.
4070       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4071     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4072       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4073              "Expected only non-SPMD kernels!");
4074       // All reaching kernels are in non-SPMD mode. Update all function
4075       // calls to __kmpc_is_spmd_exec_mode to 0.
4076       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4077     } else {
      // There are no reaching kernels, so we cannot tell whether the
      // associated call site can be folded. At this point, SimplifiedValue
      // must be none.
4081       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4082     }
4083 
4084     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4085                                                     : ChangeStatus::CHANGED;
4086   }
4087 
4088   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4089   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4090     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4091 
4092     CallBase &CB = cast<CallBase>(getAssociatedValue());
4093     Function *F = CB.getFunction();
4094     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4095         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4096 
4097     if (!ExecutionDomainAA.isValidState())
4098       return indicatePessimisticFixpoint();
4099 
4100     auto &Ctx = getAnchorValue().getContext();
4101     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4102       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4103     else
4104       return indicatePessimisticFixpoint();
4105 
4106     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4107                                                     : ChangeStatus::CHANGED;
4108   }
4109 
4110   /// Fold __kmpc_parallel_level into a constant if possible.
4111   ChangeStatus foldParallelLevel(Attributor &A) {
4112     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4113 
4114     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4115         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4116 
4117     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4118       return indicatePessimisticFixpoint();
4119 
4120     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4121       return indicatePessimisticFixpoint();
4122 
4123     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4124       assert(!SimplifiedValue.hasValue() &&
4125              "SimplifiedValue should keep none at this point");
4126       return ChangeStatus::UNCHANGED;
4127     }
4128 
4129     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4130     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4131     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4132       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4133                                           DepClassTy::REQUIRED);
4134       if (!AA.SPMDCompatibilityTracker.isValidState())
4135         return indicatePessimisticFixpoint();
4136 
4137       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4138         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4139           ++KnownSPMDCount;
4140         else
4141           ++AssumedSPMDCount;
4142       } else {
4143         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4144           ++KnownNonSPMDCount;
4145         else
4146           ++AssumedNonSPMDCount;
4147       }
4148     }
4149 
4150     if ((AssumedSPMDCount + KnownSPMDCount) &&
4151         (AssumedNonSPMDCount + KnownNonSPMDCount))
4152       return indicatePessimisticFixpoint();
4153 
4154     auto &Ctx = getAnchorValue().getContext();
4155     // If the caller can only be reached by SPMD kernel entries, the parallel
4156     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4157     // kernel entries, it is 0.
4158     if (AssumedSPMDCount || KnownSPMDCount) {
4159       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4160              "Expected only SPMD kernels!");
4161       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4162     } else {
4163       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4164              "Expected only non-SPMD kernels!");
4165       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4166     }
4167     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4168                                                     : ChangeStatus::CHANGED;
4169   }
4170 
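  /// Fold a runtime call into a constant if all reaching kernels carry the
  /// function attribute \p Attr with the same value.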
4171   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
    // Specialize only if all reaching kernels agree on the attribute's
    // constant value.
4173     int32_t CurrentAttrValue = -1;
4174     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4175 
4176     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4177         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4178 
4179     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4180       return indicatePessimisticFixpoint();
4181 
4182     // Iterate over the kernels that reach this function
4183     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4184       int32_t NextAttrVal = -1;
4185       if (K->hasFnAttribute(Attr))
4186         NextAttrVal =
4187             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4188 
4189       if (NextAttrVal == -1 ||
4190           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4191         return indicatePessimisticFixpoint();
4192       CurrentAttrValue = NextAttrVal;
4193     }
4194 
4195     if (CurrentAttrValue != -1) {
4196       auto &Ctx = getAnchorValue().getContext();
4197       SimplifiedValue =
4198           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4199     }
4200     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4201                                                     : ChangeStatus::CHANGED;
4202   }
4203 
4204   /// An optional value the associated value is assumed to fold to. That is, we
4205   /// assume the associated value (which is a call) can be replaced by this
4206   /// simplified value.
4207   Optional<Value *> SimplifiedValue;
4208 
4209   /// The runtime function kind of the callee of the associated call site.
4210   RuntimeFunction RFKind;
4211 };
4212 
4213 } // namespace
4214 
/// Register an AAFoldRuntimeCall abstract attribute for every call site of
/// the runtime function \p RF.
4216 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4217   auto &RFI = OMPInfoCache.RFIs[RF];
4218   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4219     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4220     if (!CI)
4221       return false;
4222     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4223         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4224         DepClassTy::NONE, /* ForceUpdate */ false,
4225         /* UpdateAfterInit */ false);
4226     return false;
4227   });
4228 }
4229 
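/// Seed the Attributor with the abstract attributes this pass builds on, for
/// the functions and runtime call sites in the current SCC (or module).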
4230 void OpenMPOpt::registerAAs(bool IsModulePass) {
  if (SCC.empty())
    return;

4234   if (IsModulePass) {
4235     // Ensure we create the AAKernelInfo AAs first and without triggering an
4236     // update. This will make sure we register all value simplification
4237     // callbacks before any other AA has the chance to create an AAValueSimplify
4238     // or similar.
4239     for (Function *Kernel : OMPInfoCache.Kernels)
4240       A.getOrCreateAAFor<AAKernelInfo>(
4241           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4242           DepClassTy::NONE, /* ForceUpdate */ false,
          /* UpdateAfterInit */ false);

4246     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4247     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4248     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4249     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4250     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4251   }
4252 
4253   // Create CallSite AA for all Getters.
4254   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4255     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4256 
4257     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4258 
4259     auto CreateAA = [&](Use &U, Function &Caller) {
4260       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4261       if (!CI)
4262         return false;
4263 
4264       auto &CB = cast<CallBase>(*CI);
4265 
4266       IRPosition CBPos = IRPosition::callsite_function(CB);
4267       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4268       return false;
4269     };
4270 
4271     GetterRFI.foreachUse(SCC, CreateAA);
4272   }
4273   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4274   auto CreateAA = [&](Use &U, Function &F) {
4275     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4276     return false;
4277   };
4278   if (!DisableOpenMPOptDeglobalization)
4279     GlobalizationRFI.foreachUse(SCC, CreateAA);
4280 
  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
  // every function if this is a device module.
4283   if (!isOpenMPDevice(M))
4284     return;
4285 
4286   for (auto *F : SCC) {
4287     if (F->isDeclaration())
4288       continue;
4289 
4290     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4291     if (!DisableOpenMPOptDeglobalization)
4292       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4293 
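    // Eagerly query the simplified value of loads so the Attributor seeds
    // the corresponding value simplification AAs.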
4294     for (auto &I : instructions(*F)) {
4295       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4296         bool UsedAssumedInformation = false;
4297         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4298                                UsedAssumedInformation);
4299       }
4300     }
4301   }
4302 }
4303 
4304 const char AAICVTracker::ID = 0;
4305 const char AAKernelInfo::ID = 0;
4306 const char AAExecutionDomain::ID = 0;
4307 const char AAHeapToShared::ID = 0;
4308 const char AAFoldRuntimeCall::ID = 0;
4309 
4310 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4311                                               Attributor &A) {
4312   AAICVTracker *AA = nullptr;
4313   switch (IRP.getPositionKind()) {
4314   case IRPosition::IRP_INVALID:
4315   case IRPosition::IRP_FLOAT:
4316   case IRPosition::IRP_ARGUMENT:
4317   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4318     llvm_unreachable("ICVTracker can only be created for function position!");
4319   case IRPosition::IRP_RETURNED:
4320     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4321     break;
4322   case IRPosition::IRP_CALL_SITE_RETURNED:
4323     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4324     break;
4325   case IRPosition::IRP_CALL_SITE:
4326     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4327     break;
4328   case IRPosition::IRP_FUNCTION:
4329     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4330     break;
4331   }
4332 
4333   return *AA;
4334 }
4335 
4336 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4337                                                         Attributor &A) {
4338   AAExecutionDomainFunction *AA = nullptr;
4339   switch (IRP.getPositionKind()) {
4340   case IRPosition::IRP_INVALID:
4341   case IRPosition::IRP_FLOAT:
4342   case IRPosition::IRP_ARGUMENT:
4343   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4344   case IRPosition::IRP_RETURNED:
4345   case IRPosition::IRP_CALL_SITE_RETURNED:
4346   case IRPosition::IRP_CALL_SITE:
4347     llvm_unreachable(
4348         "AAExecutionDomain can only be created for function position!");
4349   case IRPosition::IRP_FUNCTION:
4350     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4351     break;
4352   }
4353 
4354   return *AA;
4355 }
4356 
4357 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4358                                                   Attributor &A) {
4359   AAHeapToSharedFunction *AA = nullptr;
4360   switch (IRP.getPositionKind()) {
4361   case IRPosition::IRP_INVALID:
4362   case IRPosition::IRP_FLOAT:
4363   case IRPosition::IRP_ARGUMENT:
4364   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4365   case IRPosition::IRP_RETURNED:
4366   case IRPosition::IRP_CALL_SITE_RETURNED:
4367   case IRPosition::IRP_CALL_SITE:
4368     llvm_unreachable(
4369         "AAHeapToShared can only be created for function position!");
4370   case IRPosition::IRP_FUNCTION:
4371     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4372     break;
4373   }
4374 
4375   return *AA;
4376 }
4377 
4378 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4379                                               Attributor &A) {
4380   AAKernelInfo *AA = nullptr;
4381   switch (IRP.getPositionKind()) {
4382   case IRPosition::IRP_INVALID:
4383   case IRPosition::IRP_FLOAT:
4384   case IRPosition::IRP_ARGUMENT:
4385   case IRPosition::IRP_RETURNED:
4386   case IRPosition::IRP_CALL_SITE_RETURNED:
4387   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4388     llvm_unreachable("KernelInfo can only be created for function position!");
4389   case IRPosition::IRP_CALL_SITE:
4390     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4391     break;
4392   case IRPosition::IRP_FUNCTION:
4393     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4394     break;
4395   }
4396 
4397   return *AA;
4398 }
4399 
4400 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4401                                                         Attributor &A) {
4402   AAFoldRuntimeCall *AA = nullptr;
4403   switch (IRP.getPositionKind()) {
4404   case IRPosition::IRP_INVALID:
4405   case IRPosition::IRP_FLOAT:
4406   case IRPosition::IRP_ARGUMENT:
4407   case IRPosition::IRP_RETURNED:
4408   case IRPosition::IRP_FUNCTION:
4409   case IRPosition::IRP_CALL_SITE:
4410   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4411     llvm_unreachable("KernelInfo can only be created for call site position!");
4412   case IRPosition::IRP_CALL_SITE_RETURNED:
4413     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4414     break;
4415   }
4416 
4417   return *AA;
4418 }
4419 
4420 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4421   if (!containsOpenMP(M))
4422     return PreservedAnalyses::all();
4423   if (DisableOpenMPOptimizations)
4424     return PreservedAnalyses::all();
4425 
4426   FunctionAnalysisManager &FAM =
4427       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4428   KernelSet Kernels = getDeviceKernels(M);
4429 
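  // A function is considered "called" if it is a kernel or has any user
  // other than a block address.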
4430   auto IsCalled = [&](Function &F) {
4431     if (Kernels.contains(&F))
4432       return true;
4433     for (const User *U : F.users())
4434       if (!isa<BlockAddress>(U))
4435         return true;
4436     return false;
4437   };
4438 
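  // Emit an analysis remark when a function could not be internalized.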
4439   auto EmitRemark = [&](Function &F) {
4440     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4441     ORE.emit([&]() {
4442       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4443       return ORA << "Could not internalize function. "
4444                  << "Some optimizations may not be possible. [OMP140]";
4445     });
4446   };
4447 
  // Create internal copies of each function if this is a kernel module. This
  // allows interprocedural passes to see every call edge.
4450   DenseMap<Function *, Function *> InternalizedMap;
4451   if (isOpenMPDevice(M)) {
4452     SmallPtrSet<Function *, 16> InternalizeFns;
4453     for (Function &F : M)
4454       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4455           !DisableInternalization) {
4456         if (Attributor::isInternalizable(F)) {
4457           InternalizeFns.insert(&F);
4458         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4459           EmitRemark(F);
4460         }
4461       }
4462 
4463     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4464   }
4465 
4466   // Look at every function in the Module unless it was internalized.
4467   SmallVector<Function *, 16> SCC;
4468   for (Function &F : M)
4469     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4470       SCC.push_back(&F);
4471 
4472   if (SCC.empty())
4473     return PreservedAnalyses::all();
4474 
4475   AnalysisGetter AG(FAM);
4476 
4477   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4478     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4479   };
4480 
4481   BumpPtrAllocator Allocator;
4482   CallGraphUpdater CGUpdater;
4483 
4484   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4485   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4486 
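  // Allow more Attributor fixpoint iterations when optimizing device code.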
4487   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4488   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4489                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4490 
4491   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4492   bool Changed = OMPOpt.run(true);
4493 
4494   // Optionally inline device functions for potentially better performance.
4495   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4496     for (Function &F : M)
4497       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4498           !F.hasFnAttribute(Attribute::NoInline))
4499         F.addFnAttr(Attribute::AlwaysInline);
4500 
4501   if (PrintModuleAfterOptimizations)
4502     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4503 
4504   if (Changed)
4505     return PreservedAnalyses::none();
4506 
4507   return PreservedAnalyses::all();
4508 }
4509 
4510 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4511                                           CGSCCAnalysisManager &AM,
4512                                           LazyCallGraph &CG,
4513                                           CGSCCUpdateResult &UR) {
4514   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4515     return PreservedAnalyses::all();
4516   if (DisableOpenMPOptimizations)
4517     return PreservedAnalyses::all();
4518 
4519   SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCCs.
4521   for (LazyCallGraph::Node &N : C) {
4522     Function *Fn = &N.getFunction();
4523     SCC.push_back(Fn);
4524   }
4525 
4526   if (SCC.empty())
4527     return PreservedAnalyses::all();
4528 
4529   Module &M = *C.begin()->getFunction().getParent();
4530 
4531   KernelSet Kernels = getDeviceKernels(M);
4532 
4533   FunctionAnalysisManager &FAM =
4534       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4535 
4536   AnalysisGetter AG(FAM);
4537 
4538   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4539     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4540   };
4541 
4542   BumpPtrAllocator Allocator;
4543   CallGraphUpdater CGUpdater;
4544   CGUpdater.initialize(CG, C, AM, UR);
4545 
4546   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4547   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4548                                 /*CGSCC*/ Functions, Kernels);
4549 
4550   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4551   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4552                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4553 
4554   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4555   bool Changed = OMPOpt.run(false);
4556 
4557   if (PrintModuleAfterOptimizations)
4558     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4559 
4560   if (Changed)
4561     return PreservedAnalyses::none();
4562 
4563   return PreservedAnalyses::all();
4564 }
4565 
4566 namespace {
4567 
4568 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4569   CallGraphUpdater CGUpdater;
4570   static char ID;
4571 
4572   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4573     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4574   }
4575 
4576   void getAnalysisUsage(AnalysisUsage &AU) const override {
4577     CallGraphSCCPass::getAnalysisUsage(AU);
4578   }
4579 
4580   bool runOnSCC(CallGraphSCC &CGSCC) override {
4581     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4582       return false;
4583     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4584       return false;
4585 
4586     SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCCs.
4588     for (CallGraphNode *CGN : CGSCC) {
4589       Function *Fn = CGN->getFunction();
4590       if (!Fn || Fn->isDeclaration())
4591         continue;
4592       SCC.push_back(Fn);
4593     }
4594 
4595     if (SCC.empty())
4596       return false;
4597 
4598     Module &M = CGSCC.getCallGraph().getModule();
4599     KernelSet Kernels = getDeviceKernels(M);
4600 
4601     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4602     CGUpdater.initialize(CG, CGSCC);
4603 
4604     // Maintain a map of functions to avoid rebuilding the ORE
4605     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4606     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4607       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4608       if (!ORE)
4609         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4610       return *ORE;
4611     };
4612 
4613     AnalysisGetter AG;
4614     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4615     BumpPtrAllocator Allocator;
4616     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4617                                   Allocator,
4618                                   /*CGSCC*/ Functions, Kernels);
4619 
4620     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4621     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4622                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4623 
4624     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4625     bool Result = OMPOpt.run(false);
4626 
4627     if (PrintModuleAfterOptimizations)
4628       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4629 
4630     return Result;
4631   }
4632 
4633   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4634 };
4635 
4636 } // end anonymous namespace
4637 
4638 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4639   // TODO: Create a more cross-platform way of determining device kernels.
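  // Device kernels are listed in the "nvvm.annotations" named metadata as
  // entries whose first operand is the kernel function and whose second
  // operand is the string "kernel".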
4640   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4641   KernelSet Kernels;
4642 
4643   if (!MD)
4644     return Kernels;
4645 
4646   for (auto *Op : MD->operands()) {
4647     if (Op->getNumOperands() < 2)
4648       continue;
4649     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4650     if (!KindID || KindID->getString() != "kernel")
4651       continue;
4652 
4653     Function *KernelFn =
4654         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4655     if (!KernelFn)
4656       continue;
4657 
4658     ++NumOpenMPTargetRegionKernels;
4659 
4660     Kernels.insert(KernelFn);
4661   }
4662 
4663   return Kernels;
4664 }
4665 
4666 bool llvm::omp::containsOpenMP(Module &M) {
4667   Metadata *MD = M.getModuleFlag("openmp");
4668   if (!MD)
4669     return false;
4670 
4671   return true;
4672 }
4673 
4674 bool llvm::omp::isOpenMPDevice(Module &M) {
4675   Metadata *MD = M.getModuleFlag("openmp-device");
4676   if (!MD)
4677     return false;
4678 
4679   return true;
4680 }
4681 
4682 char OpenMPOptCGSCCLegacyPass::ID = 0;
4683 
4684 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4685                       "OpenMP specific optimizations", false, false)
4686 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4687 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4688                     "OpenMP specific optimizations", false, false)
4689 
4690 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4691   return new OpenMPOptCGSCCLegacyPass();
4692 }
4693