1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/InitializePasses.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Transforms/IPO.h"
39 #include "llvm/Transforms/IPO/Attributor.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
42 #include "llvm/Transforms/Utils/CodeExtractor.h"
43 
44 using namespace llvm;
45 using namespace omp;
46 
47 #define DEBUG_TYPE "openmp-opt"
48 
49 static cl::opt<bool> DisableOpenMPOptimizations(
50     "openmp-opt-disable", cl::ZeroOrMore,
51     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
52     cl::init(false));
53 
54 static cl::opt<bool> EnableParallelRegionMerging(
55     "openmp-opt-enable-merging", cl::ZeroOrMore,
56     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
57     cl::init(false));
58 
59 static cl::opt<bool>
60     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
61                            cl::desc("Disable function internalization."),
62                            cl::Hidden, cl::init(false));
63 
64 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
65                                     cl::Hidden);
66 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
67                                         cl::init(false), cl::Hidden);
68 
69 static cl::opt<bool> HideMemoryTransferLatency(
70     "openmp-hide-memory-transfer-latency",
71     cl::desc("[WIP] Tries to hide the latency of host to device memory"
72              " transfers"),
73     cl::Hidden, cl::init(false));
74 
75 static cl::opt<bool> DisableOpenMPOptDeglobalization(
76     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
77     cl::desc("Disable OpenMP optimizations involving deglobalization."),
78     cl::Hidden, cl::init(false));
79 
80 static cl::opt<bool> DisableOpenMPOptSPMDization(
81     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
82     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
83     cl::Hidden, cl::init(false));
84 
85 static cl::opt<bool> DisableOpenMPOptFolding(
86     "openmp-opt-disable-folding", cl::ZeroOrMore,
87     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
88     cl::init(false));
89 
90 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
91     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
92     cl::desc("Disable OpenMP optimizations that replace the state machine."),
93     cl::Hidden, cl::init(false));
94 
95 static cl::opt<bool> PrintModuleAfterOptimizations(
96     "openmp-opt-print-module", cl::ZeroOrMore,
97     cl::desc("Print the current module after OpenMP optimizations."),
98     cl::Hidden, cl::init(false));
99 
100 static cl::opt<bool> AlwaysInlineDeviceFunctions(
101     "openmp-opt-inline-device", cl::ZeroOrMore,
102     cl::desc("Inline all applicable functions on the device."), cl::Hidden,
103     cl::init(false));
104 
105 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
106           "Number of OpenMP runtime calls deduplicated");
107 STATISTIC(NumOpenMPParallelRegionsDeleted,
108           "Number of OpenMP parallel regions deleted");
109 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
110           "Number of OpenMP runtime functions identified");
111 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
112           "Number of OpenMP runtime function uses identified");
113 STATISTIC(NumOpenMPTargetRegionKernels,
114           "Number of OpenMP target region entry points (=kernels) identified");
115 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
116           "Number of OpenMP target region entry points (=kernels) executed in "
117           "SPMD-mode instead of generic-mode");
118 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
119           "Number of OpenMP target region entry points (=kernels) executed in "
120           "generic-mode without a state machine");
121 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
122           "Number of OpenMP target region entry points (=kernels) executed in "
123           "generic-mode with customized state machines with fallback");
124 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
125           "Number of OpenMP target region entry points (=kernels) executed in "
126           "generic-mode with customized state machines without fallback");
127 STATISTIC(
128     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
129     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
130 STATISTIC(NumOpenMPParallelRegionsMerged,
131           "Number of OpenMP parallel regions merged");
132 STATISTIC(NumBytesMovedToSharedMemory,
133           "Amount of memory pushed to shared memory");
134 
135 #if !defined(NDEBUG)
136 static constexpr auto TAG = "[" DEBUG_TYPE "]";
137 #endif
138 
139 namespace {
140 
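/// Address spaces used for device memory. The numeric values are assumed to
/// match the common GPU target numbering (e.g., NVPTX/AMDGPU: 1 = global,
/// 3 = shared/LDS, 4 = constant, 5 = private/local).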
141 enum class AddressSpace : unsigned {
142   Generic = 0,
143   Global = 1,
144   Shared = 3,
145   Constant = 4,
146   Local = 5,
147 };
148 
149 struct AAHeapToShared;
150 
151 struct AAICVTracker;
152 
153 /// OpenMP-specific information. For now, this stores the RFIs and ICVs that
154 /// are also needed for Attributor runs.
155 struct OMPInformationCache : public InformationCache {
156   OMPInformationCache(Module &M, AnalysisGetter &AG,
157                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
158                       SmallPtrSetImpl<Kernel> &Kernels)
159       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
160         Kernels(Kernels) {
161 
162     OMPBuilder.initialize();
163     initializeRuntimeFunctions();
164     initializeInternalControlVars();
165   }
166 
167   /// Generic information that describes an internal control variable.
168   struct InternalControlVarInfo {
169     /// The kind, as described by InternalControlVar enum.
170     InternalControlVar Kind;
171 
172     /// The name of the ICV.
173     StringRef Name;
174 
175     /// Environment variable associated with this ICV.
176     StringRef EnvVarName;
177 
178     /// Initial value kind.
179     ICVInitValue InitKind;
180 
181     /// Initial value.
182     ConstantInt *InitValue;
183 
184     /// Setter RTL function associated with this ICV.
185     RuntimeFunction Setter;
186 
187     /// Getter RTL function associated with this ICV.
188     RuntimeFunction Getter;
189 
190     /// RTL Function corresponding to the override clause of this ICV
191     RuntimeFunction Clause;
192   };
193 
194   /// Generic information that describes a runtime function
195   struct RuntimeFunctionInfo {
196 
197     /// The kind, as described by the RuntimeFunction enum.
198     RuntimeFunction Kind;
199 
200     /// The name of the function.
201     StringRef Name;
202 
203     /// Flag to indicate a variadic function.
204     bool IsVarArg;
205 
206     /// The return type of the function.
207     Type *ReturnType;
208 
209     /// The argument types of the function.
210     SmallVector<Type *, 8> ArgumentTypes;
211 
212     /// The declaration if available.
213     Function *Declaration = nullptr;
214 
215     /// Uses of this runtime function per function containing the use.
216     using UseVector = SmallVector<Use *, 16>;
217 
218     /// Clear UsesMap for runtime function.
219     void clearUsesMap() { UsesMap.clear(); }
220 
221     /// Boolean conversion that is true if the runtime function was found.
222     operator bool() const { return Declaration; }
223 
224     /// Return the vector of uses in function \p F.
225     UseVector &getOrCreateUseVector(Function *F) {
226       std::shared_ptr<UseVector> &UV = UsesMap[F];
227       if (!UV)
228         UV = std::make_shared<UseVector>();
229       return *UV;
230     }
231 
232     /// Return the vector of uses in function \p F or `nullptr` if there are
233     /// none.
234     const UseVector *getUseVector(Function &F) const {
235       auto I = UsesMap.find(&F);
236       if (I != UsesMap.end())
237         return I->second.get();
238       return nullptr;
239     }
240 
241     /// Return how many functions contain uses of this runtime function.
242     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
243 
244     /// Return the number of arguments (or the minimal number for variadic
245     /// functions).
246     size_t getNumArgs() const { return ArgumentTypes.size(); }
247 
248     /// Run the callback \p CB on each use and forget the use if the result is
249     /// true. The callback will be fed the function in which the use was
250     /// encountered as second argument.
251     void foreachUse(SmallVectorImpl<Function *> &SCC,
252                     function_ref<bool(Use &, Function &)> CB) {
253       for (Function *F : SCC)
254         foreachUse(CB, F);
255     }
256 
257     /// Run the callback \p CB on each use within the function \p F and forget
258     /// the use if the result is true.
259     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
260       SmallVector<unsigned, 8> ToBeDeleted;
261       ToBeDeleted.clear();
262 
263       unsigned Idx = 0;
264       UseVector &UV = getOrCreateUseVector(F);
265 
266       for (Use *U : UV) {
267         if (CB(*U, *F))
268           ToBeDeleted.push_back(Idx);
269         ++Idx;
270       }
271 
272       // Remove the to-be-deleted indices in reverse order; removing larger
273       // indices first (swap with the back and pop) keeps the smaller indices valid.
274       while (!ToBeDeleted.empty()) {
275         unsigned Idx = ToBeDeleted.pop_back_val();
276         UV[Idx] = UV.back();
277         UV.pop_back();
278       }
279     }
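    // Typical usage (sketch, mirroring the callbacks later in this file):
    //   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    //     if (CallInst *CI = getCallIfRegularCall(U, &RFI)) { /* ... */ }
    //     return /* forget this use = */ false;
    //   });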
280 
281   private:
282     /// Map from functions to all uses of this runtime function contained in
283     /// them.
284     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
285 
286   public:
287     /// Iterators for the uses of this runtime function.
288     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
289     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
290   };
291 
292   /// An OpenMP-IR-Builder instance
293   OpenMPIRBuilder OMPBuilder;
294 
295   /// Map from runtime function kind to the runtime function description.
296   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
297                   RuntimeFunction::OMPRTL___last>
298       RFIs;
299 
300   /// Map from function declarations/definitions to their runtime enum type.
301   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
302 
303   /// Map from ICV kind to the ICV description.
304   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
305                   InternalControlVar::ICV___last>
306       ICVs;
307 
308   /// Helper to initialize all internal control variable information for those
309   /// defined in OMPKinds.def.
310   void initializeInternalControlVars() {
311 #define ICV_RT_SET(_Name, RTL)                                                 \
312   {                                                                            \
313     auto &ICV = ICVs[_Name];                                                   \
314     ICV.Setter = RTL;                                                          \
315   }
316 #define ICV_RT_GET(Name, RTL)                                                  \
317   {                                                                            \
318     auto &ICV = ICVs[Name];                                                    \
319     ICV.Getter = RTL;                                                          \
320   }
321 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
322   {                                                                            \
323     auto &ICV = ICVs[Enum];                                                    \
324     ICV.Name = _Name;                                                          \
325     ICV.Kind = Enum;                                                           \
326     ICV.InitKind = Init;                                                       \
327     ICV.EnvVarName = _EnvVarName;                                              \
328     switch (ICV.InitKind) {                                                    \
329     case ICV_IMPLEMENTATION_DEFINED:                                           \
330       ICV.InitValue = nullptr;                                                 \
331       break;                                                                   \
332     case ICV_ZERO:                                                             \
333       ICV.InitValue = ConstantInt::get(                                        \
334           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
335       break;                                                                   \
336     case ICV_FALSE:                                                            \
337       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
338       break;                                                                   \
339     case ICV_LAST:                                                             \
340       break;                                                                   \
341     }                                                                          \
342   }
343 #include "llvm/Frontend/OpenMP/OMPKinds.def"
344   }
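  // For illustration (a sketch; the exact entries live in OMPKinds.def): a
  // hypothetical ICV_DATA_ENV(ICV_nthreads, "nthreads", "OMP_NUM_THREADS",
  // ICV_IMPLEMENTATION_DEFINED) entry would fill in ICVs[ICV_nthreads] with
  // its name, environment variable, and a null InitValue, while ICV_RT_GET
  // would associate the corresponding getter runtime function.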
345 
346   /// Returns true if the function declaration \p F matches the runtime
347   /// function types, that is, return type \p RTFRetType, and argument types
348   /// \p RTFArgTypes.
349   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
350                                   SmallVector<Type *, 8> &RTFArgTypes) {
351     // TODO: We should output information to the user (under debug output
352     //       and via remarks).
353 
354     if (!F)
355       return false;
356     if (F->getReturnType() != RTFRetType)
357       return false;
358     if (F->arg_size() != RTFArgTypes.size())
359       return false;
360 
361     auto RTFTyIt = RTFArgTypes.begin();
362     for (Argument &Arg : F->args()) {
363       if (Arg.getType() != *RTFTyIt)
364         return false;
365 
366       ++RTFTyIt;
367     }
368 
369     return true;
370   }
371 
372   // Helper to collect all uses of the declaration in the UsesMap.
373   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
374     unsigned NumUses = 0;
375     if (!RFI.Declaration)
376       return NumUses;
377     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
378 
379     if (CollectStats) {
380       NumOpenMPRuntimeFunctionsIdentified += 1;
381       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
382     }
383 
384     // TODO: We directly convert uses into proper calls and unknown uses.
385     for (Use &U : RFI.Declaration->uses()) {
386       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
387         if (ModuleSlice.count(UserI->getFunction())) {
388           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
389           ++NumUses;
390         }
391       } else {
392         RFI.getOrCreateUseVector(nullptr).push_back(&U);
393         ++NumUses;
394       }
395     }
396     return NumUses;
397   }
398 
399   // Helper function to recollect uses of a runtime function.
400   void recollectUsesForFunction(RuntimeFunction RTF) {
401     auto &RFI = RFIs[RTF];
402     RFI.clearUsesMap();
403     collectUses(RFI, /*CollectStats*/ false);
404   }
405 
406   // Helper function to recollect uses of all runtime functions.
407   void recollectUses() {
408     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
409       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
410   }
411 
412   /// Helper to initialize all runtime function information for those defined
413   /// in OMPKinds.def.
414   void initializeRuntimeFunctions() {
415     Module &M = *((*ModuleSlice.begin())->getParent());
416 
417     // Helper macros for handling __VA_ARGS__ in OMP_RTL
418 #define OMP_TYPE(VarName, ...)                                                 \
419   Type *VarName = OMPBuilder.VarName;                                          \
420   (void)VarName;
421 
422 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
423   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
424   (void)VarName##Ty;                                                           \
425   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
426   (void)VarName##PtrTy;
427 
428 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
429   FunctionType *VarName = OMPBuilder.VarName;                                  \
430   (void)VarName;                                                               \
431   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
432   (void)VarName##Ptr;
433 
434 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
435   StructType *VarName = OMPBuilder.VarName;                                    \
436   (void)VarName;                                                               \
437   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
438   (void)VarName##Ptr;
439 
440 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
441   {                                                                            \
442     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
443     Function *F = M.getFunction(_Name);                                        \
444     RTLFunctions.insert(F);                                                    \
445     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
446       RuntimeFunctionIDMap[F] = _Enum;                                         \
447       F->removeFnAttr(Attribute::NoInline);                                    \
448       auto &RFI = RFIs[_Enum];                                                 \
449       RFI.Kind = _Enum;                                                        \
450       RFI.Name = _Name;                                                        \
451       RFI.IsVarArg = _IsVarArg;                                                \
452       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
453       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
454       RFI.Declaration = F;                                                     \
455       unsigned NumUses = collectUses(RFI);                                     \
456       (void)NumUses;                                                           \
457       LLVM_DEBUG({                                                             \
458         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
459                << " found\n";                                                  \
460         if (RFI.Declaration)                                                   \
461           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
462                  << RFI.getNumFunctionsWithUses()                              \
463                  << " different functions.\n";                                 \
464       });                                                                      \
465     }                                                                          \
466   }
467 #include "llvm/Frontend/OpenMP/OMPKinds.def"
468 
469     // TODO: We should attach the attributes defined in OMPKinds.def.
470   }
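  // For illustration (a sketch, not the verbatim .def entry): an
  // OMP_RTL(OMPRTL___kmpc_fork_call, "__kmpc_fork_call", true, Void, ...)
  // entry results in RFIs[OMPRTL___kmpc_fork_call] describing the variadic
  // fork call, with RuntimeFunctionIDMap mapping its declaration back to the
  // enum and all uses in the module slice collected via collectUses().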
471 
472   /// Collection of known kernels (\see Kernel) in the module.
473   SmallPtrSetImpl<Kernel> &Kernels;
474 
475   /// Collection of known OpenMP runtime functions.
476   DenseSet<const Function *> RTLFunctions;
477 };
478 
479 template <typename Ty, bool InsertInvalidates = true>
480 struct BooleanStateWithSetVector : public BooleanState {
481   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
482   bool insert(const Ty &Elem) {
483     if (InsertInvalidates)
484       BooleanState::indicatePessimisticFixpoint();
485     return Set.insert(Elem);
486   }
487 
488   const Ty &operator[](int Idx) const { return Set[Idx]; }
489   bool operator==(const BooleanStateWithSetVector &RHS) const {
490     return BooleanState::operator==(RHS) && Set == RHS.Set;
491   }
492   bool operator!=(const BooleanStateWithSetVector &RHS) const {
493     return !(*this == RHS);
494   }
495 
496   bool empty() const { return Set.empty(); }
497   size_t size() const { return Set.size(); }
498 
499   /// "Clamp" this state with \p RHS.
500   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
501     BooleanState::operator^=(RHS);
502     Set.insert(RHS.Set.begin(), RHS.Set.end());
503     return *this;
504   }
505 
506 private:
507   /// A set to keep track of elements.
508   SetVector<Ty> Set;
509 
510 public:
511   typename decltype(Set)::iterator begin() { return Set.begin(); }
512   typename decltype(Set)::iterator end() { return Set.end(); }
513   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
514   typename decltype(Set)::const_iterator end() const { return Set.end(); }
515 };
516 
517 template <typename Ty, bool InsertInvalidates = true>
518 using BooleanStateWithPtrSetVector =
519     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
520 
521 struct KernelInfoState : AbstractState {
522   /// Flag to track if we reached a fixpoint.
523   bool IsAtFixpoint = false;
524 
525   /// The parallel regions (identified by the outlined parallel functions) that
526   /// can be reached from the associated function.
527   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
528       ReachedKnownParallelRegions;
529 
530   /// State to track what parallel region we might reach.
531   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
532 
533   /// State to track if we are in SPMD-mode, assumed or known, and why we decided
534   /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
535   /// false.
536   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
537 
538   /// The __kmpc_target_init call in this kernel, if any. If we find more than
539   /// one we abort as the kernel is malformed.
540   CallBase *KernelInitCB = nullptr;
541 
542   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
543   /// one we abort as the kernel is malformed.
544   CallBase *KernelDeinitCB = nullptr;
545 
546   /// Flag to indicate if the associated function is a kernel entry.
547   bool IsKernelEntry = false;
548 
549   /// State to track what kernel entries can reach the associated function.
550   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
551 
552   /// State to indicate if we can track parallel level of the associated
553   /// function. We will give up tracking if we encounter an unknown caller or the
554   /// caller is __kmpc_parallel_51.
555   BooleanStateWithSetVector<uint8_t> ParallelLevels;
556 
557   /// Abstract State interface
558   ///{
559 
560   KernelInfoState() {}
561   KernelInfoState(bool BestState) {
562     if (!BestState)
563       indicatePessimisticFixpoint();
564   }
565 
566   /// See AbstractState::isValidState(...)
567   bool isValidState() const override { return true; }
568 
569   /// See AbstractState::isAtFixpoint(...)
570   bool isAtFixpoint() const override { return IsAtFixpoint; }
571 
572   /// See AbstractState::indicatePessimisticFixpoint(...)
573   ChangeStatus indicatePessimisticFixpoint() override {
574     IsAtFixpoint = true;
575     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
576     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
577     return ChangeStatus::CHANGED;
578   }
579 
580   /// See AbstractState::indicateOptimisticFixpoint(...)
581   ChangeStatus indicateOptimisticFixpoint() override {
582     IsAtFixpoint = true;
583     return ChangeStatus::UNCHANGED;
584   }
585 
586   /// Return the assumed state
587   KernelInfoState &getAssumed() { return *this; }
588   const KernelInfoState &getAssumed() const { return *this; }
589 
590   bool operator==(const KernelInfoState &RHS) const {
591     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
592       return false;
593     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
594       return false;
595     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
596       return false;
597     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
598       return false;
599     return true;
600   }
601 
602   /// Return the best possible (optimistic) state.
603   static KernelInfoState getBestState() { return KernelInfoState(true); }
604 
605   static KernelInfoState getBestState(KernelInfoState &KIS) {
606     return getBestState();
607   }
608 
609   /// Return the worst possible (pessimistic) state.
610   static KernelInfoState getWorstState() { return KernelInfoState(false); }
611 
612   /// "Clamp" this state with \p KIS.
613   KernelInfoState operator^=(const KernelInfoState &KIS) {
614     // Do not merge two different _init and _deinit call sites.
615     if (KIS.KernelInitCB) {
616       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
617         indicatePessimisticFixpoint();
618       KernelInitCB = KIS.KernelInitCB;
619     }
620     if (KIS.KernelDeinitCB) {
621       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
622         indicatePessimisticFixpoint();
623       KernelDeinitCB = KIS.KernelDeinitCB;
624     }
625     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
626     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
627     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
628     return *this;
629   }
630 
631   KernelInfoState operator&=(const KernelInfoState &KIS) {
632     return (*this ^= KIS);
633   }
634 
635   ///}
636 };
637 
638 /// Maps the values physically stored (in the IR) in an offload array to a
639 /// vector in memory.
640 struct OffloadArray {
641   /// Physical array (in the IR).
642   AllocaInst *Array = nullptr;
643   /// Mapped values.
644   SmallVector<Value *, 8> StoredValues;
645   /// Last stores made in the offload array.
646   SmallVector<StoreInst *, 8> LastAccesses;
647 
648   OffloadArray() = default;
649 
650   /// Initializes the OffloadArray with the values stored in \p Array before
651   /// instruction \p Before is reached. Returns false if the initialization
652   /// fails.
653   /// This MUST be used immediately after the construction of the object.
654   bool initialize(AllocaInst &Array, Instruction &Before) {
655     if (!Array.getAllocatedType()->isArrayTy())
656       return false;
657 
658     if (!getValues(Array, Before))
659       return false;
660 
661     this->Array = &Array;
662     return true;
663   }
664 
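  // These indices refer to the argument positions of the
  // __tgt_target_data_begin_mapper-style runtime calls whose offload arrays
  // are parsed below (see getValuesInOffloadArrays); the exact signature is
  // assumed from how the arguments are accessed there.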
665   static const unsigned DeviceIDArgNum = 1;
666   static const unsigned BasePtrsArgNum = 3;
667   static const unsigned PtrsArgNum = 4;
668   static const unsigned SizesArgNum = 5;
669 
670 private:
671   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
672   /// \p Array, leaving StoredValues with the values stored before the
673   /// instruction \p Before is reached.
674   bool getValues(AllocaInst &Array, Instruction &Before) {
675     // Initialize container.
676     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
677     StoredValues.assign(NumValues, nullptr);
678     LastAccesses.assign(NumValues, nullptr);
679 
680     // TODO: This assumes the instruction \p Before is in the same
681     //  BasicBlock as Array. Make it general, for any control flow graph.
682     BasicBlock *BB = Array.getParent();
683     if (BB != Before.getParent())
684       return false;
685 
686     const DataLayout &DL = Array.getModule()->getDataLayout();
687     const unsigned int PointerSize = DL.getPointerSize();
688 
689     for (Instruction &I : *BB) {
690       if (&I == &Before)
691         break;
692 
693       if (!isa<StoreInst>(&I))
694         continue;
695 
696       auto *S = cast<StoreInst>(&I);
697       int64_t Offset = -1;
698       auto *Dst =
699           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
700       if (Dst == &Array) {
701         int64_t Idx = Offset / PointerSize;
702         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
703         LastAccesses[Idx] = S;
704       }
705     }
706 
707     return isFilled();
708   }
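  // The stored-value pattern this recognizes looks roughly like (sketch,
  // names are illustrative):
  //   %offload_baseptrs = alloca [2 x i8*]
  //   %gep0 = getelementptr ... %offload_baseptrs, i64 0, i64 0
  //   store i8* %ptr0, i8** %gep0
  //   %gep1 = getelementptr ... %offload_baseptrs, i64 0, i64 1
  //   store i8* %ptr1, i8** %gep1
  //   call void @__tgt_target_data_begin_mapper(...)   ; \p Before
  // Each store's constant offset from the alloca, divided by the pointer
  // size, yields the array index for StoredValues/LastAccesses.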
709 
710   /// Returns true if all values in StoredValues and
711   /// LastAccesses are not nullptrs.
712   bool isFilled() {
713     const unsigned NumValues = StoredValues.size();
714     for (unsigned I = 0; I < NumValues; ++I) {
715       if (!StoredValues[I] || !LastAccesses[I])
716         return false;
717     }
718 
719     return true;
720   }
721 };
722 
723 struct OpenMPOpt {
724 
725   using OptimizationRemarkGetter =
726       function_ref<OptimizationRemarkEmitter &(Function *)>;
727 
728   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
729             OptimizationRemarkGetter OREGetter,
730             OMPInformationCache &OMPInfoCache, Attributor &A)
731       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
732         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
733 
734   /// Check if any remarks are enabled for openmp-opt
735   bool remarksEnabled() {
736     auto &Ctx = M.getContext();
737     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
738   }
739 
740   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
741   bool run(bool IsModulePass) {
742     if (SCC.empty())
743       return false;
744 
745     bool Changed = false;
746 
747     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
748                       << " functions in a slice with "
749                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
750 
751     if (IsModulePass) {
752       Changed |= runAttributor(IsModulePass);
753 
754       // Recollect uses, in case Attributor deleted any.
755       OMPInfoCache.recollectUses();
756 
757       // TODO: This should be folded into buildCustomStateMachine.
758       Changed |= rewriteDeviceCodeStateMachine();
759 
760       if (remarksEnabled())
761         analysisGlobalization();
762     } else {
763       if (PrintICVValues)
764         printICVs();
765       if (PrintOpenMPKernels)
766         printKernels();
767 
768       Changed |= runAttributor(IsModulePass);
769 
770       // Recollect uses, in case Attributor deleted any.
771       OMPInfoCache.recollectUses();
772 
773       Changed |= deleteParallelRegions();
774 
775       if (HideMemoryTransferLatency)
776         Changed |= hideMemTransfersLatency();
777       Changed |= deduplicateRuntimeCalls();
778       if (EnableParallelRegionMerging) {
779         if (mergeParallelRegions()) {
780           deduplicateRuntimeCalls();
781           Changed = true;
782         }
783       }
784     }
785 
786     return Changed;
787   }
788 
789   /// Print initial ICV values for testing.
790   /// FIXME: This should be done from the Attributor once it is added.
791   void printICVs() const {
792     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
793                                  ICV_proc_bind};
794 
795     for (Function *F : OMPInfoCache.ModuleSlice) {
796       for (auto ICV : ICVs) {
797         auto ICVInfo = OMPInfoCache.ICVs[ICV];
798         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
799           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
800                      << " Value: "
801                      << (ICVInfo.InitValue
802                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
803                              : "IMPLEMENTATION_DEFINED");
804         };
805 
806         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
807       }
808     }
809   }
810 
811   /// Print OpenMP GPU kernels for testing.
812   void printKernels() const {
813     for (Function *F : SCC) {
814       if (!OMPInfoCache.Kernels.count(F))
815         continue;
816 
817       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
818         return ORA << "OpenMP GPU kernel "
819                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
820       };
821 
822       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
823     }
824   }
825 
826   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
827   /// given, it has to be the callee, otherwise nullptr is returned.
828   static CallInst *getCallIfRegularCall(
829       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
830     CallInst *CI = dyn_cast<CallInst>(U.getUser());
831     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
832         (!RFI ||
833          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
834       return CI;
835     return nullptr;
836   }
837 
838   /// Return the call if \p V is a regular call. If \p RFI is given, it has to
839   /// be the callee, otherwise nullptr is returned.
840   static CallInst *getCallIfRegularCall(
841       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
842     CallInst *CI = dyn_cast<CallInst>(&V);
843     if (CI && !CI->hasOperandBundles() &&
844         (!RFI ||
845          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
846       return CI;
847     return nullptr;
848   }
849 
850 private:
851   /// Merge parallel regions when it is safe.
852   bool mergeParallelRegions() {
853     const unsigned CallbackCalleeOperand = 2;
854     const unsigned CallbackFirstArgOperand = 3;
855     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
856 
857     // Check if there are any __kmpc_fork_call calls to merge.
858     OMPInformationCache::RuntimeFunctionInfo &RFI =
859         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
860 
861     if (!RFI.Declaration)
862       return false;
863 
864     // Unmergable calls that prevent merging a parallel region.
865     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
866         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
867         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
868     };
869 
870     bool Changed = false;
871     LoopInfo *LI = nullptr;
872     DominatorTree *DT = nullptr;
873 
874     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
875 
876     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
877     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
878                          BasicBlock &ContinuationIP) {
879       BasicBlock *CGStartBB = CodeGenIP.getBlock();
880       BasicBlock *CGEndBB =
881           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
882       assert(StartBB != nullptr && "StartBB should not be null");
883       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
884       assert(EndBB != nullptr && "EndBB should not be null");
885       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
886     };
887 
888     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
889                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
890       ReplacementValue = &Inner;
891       return CodeGenIP;
892     };
893 
894     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
895 
896     /// Create a sequential execution region within a merged parallel region,
897     /// encapsulated in a master construct with a barrier for synchronization.
898     auto CreateSequentialRegion = [&](Function *OuterFn,
899                                       BasicBlock *OuterPredBB,
900                                       Instruction *SeqStartI,
901                                       Instruction *SeqEndI) {
902       // Isolate the instructions of the sequential region to a separate
903       // block.
904       BasicBlock *ParentBB = SeqStartI->getParent();
905       BasicBlock *SeqEndBB =
906           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
907       BasicBlock *SeqAfterBB =
908           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
909       BasicBlock *SeqStartBB =
910           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
911 
912       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
913              "Expected a different CFG");
914       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
915       ParentBB->getTerminator()->eraseFromParent();
916 
917       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
918                            BasicBlock &ContinuationIP) {
919         BasicBlock *CGStartBB = CodeGenIP.getBlock();
920         BasicBlock *CGEndBB =
921             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
922         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
923         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
924         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
925         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
926       };
927       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
928 
929       // Find outputs from the sequential region to outside users and
930       // broadcast their values to them.
931       for (Instruction &I : *SeqStartBB) {
932         SmallPtrSet<Instruction *, 4> OutsideUsers;
933         for (User *Usr : I.users()) {
934           Instruction &UsrI = *cast<Instruction>(Usr);
935           // Ignore outputs to LT intrinsics, code extraction for the merged
936           // parallel region will fix them.
937           if (UsrI.isLifetimeStartOrEnd())
938             continue;
939 
940           if (UsrI.getParent() != SeqStartBB)
941             OutsideUsers.insert(&UsrI);
942         }
943 
944         if (OutsideUsers.empty())
945           continue;
946 
947         // Emit an alloca in the outer region to store the broadcasted
948         // value.
949         const DataLayout &DL = M.getDataLayout();
950         AllocaInst *AllocaI = new AllocaInst(
951             I.getType(), DL.getAllocaAddrSpace(), nullptr,
952             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
953 
954         // Emit a store instruction in the sequential BB to update the
955         // value.
956         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
957 
958         // Emit a load instruction and replace the use of the output value
959         // with it.
960         for (Instruction *UsrI : OutsideUsers) {
961           LoadInst *LoadI = new LoadInst(
962               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
963           UsrI->replaceUsesOfWith(&I, LoadI);
964         }
965       }
966 
967       OpenMPIRBuilder::LocationDescription Loc(
968           InsertPointTy(ParentBB, ParentBB->end()), DL);
969       InsertPointTy SeqAfterIP =
970           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
971 
972       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
973 
974       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
975 
976       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
977                         << "\n");
978     };
979 
980     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
981     // contained in BB and only separated by instructions that can be
982     // redundantly executed in parallel. The block BB is split before the first
983     // call (in MergableCIs) and after the last so the entire region we merge
984     // into a single parallel region is contained in a single basic block
985     // without any other instructions. We use the OpenMPIRBuilder to outline
986     // that block and call the resulting function via __kmpc_fork_call.
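    // For illustration (a source-level sketch of the effect):
    //   #pragma omp parallel  { work1(); }
    //   seq();
    //   #pragma omp parallel  { work2(); }
    // becomes, roughly,
    //   #pragma omp parallel {
    //     work1();
    //     #pragma omp master  { seq(); }
    //     #pragma omp barrier
    //     work2();
    //   }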
987     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
988       // TODO: Change the interface to allow single CIs expanded, e.g, to
989       // include an outer loop.
990       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
991 
992       auto Remark = [&](OptimizationRemark OR) {
993         OR << "Parallel region merged with parallel region"
994            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
995         for (auto *CI : llvm::drop_begin(MergableCIs)) {
996           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
997           if (CI != MergableCIs.back())
998             OR << ", ";
999         }
1000         return OR << ".";
1001       };
1002 
1003       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1004 
1005       Function *OriginalFn = BB->getParent();
1006       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1007                         << " parallel regions in " << OriginalFn->getName()
1008                         << "\n");
1009 
1010       // Isolate the calls to merge in a separate block.
1011       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1012       BasicBlock *AfterBB =
1013           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1014       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1015                            "omp.par.merged");
1016 
1017       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1018       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1019       BB->getTerminator()->eraseFromParent();
1020 
1021       // Create sequential regions for sequential instructions that are
1022       // in-between mergable parallel regions.
1023       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1024            It != End; ++It) {
1025         Instruction *ForkCI = *It;
1026         Instruction *NextForkCI = *(It + 1);
1027 
1028         // Continue if there are no in-between instructions.
1029         if (ForkCI->getNextNode() == NextForkCI)
1030           continue;
1031 
1032         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1033                                NextForkCI->getPrevNode());
1034       }
1035 
1036       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1037                                                DL);
1038       IRBuilder<>::InsertPoint AllocaIP(
1039           &OriginalFn->getEntryBlock(),
1040           OriginalFn->getEntryBlock().getFirstInsertionPt());
1041       // Create the merged parallel region with default proc binding, to
1042       // avoid overriding binding settings, and without explicit cancellation.
1043       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1044           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1045           OMP_PROC_BIND_default, /* IsCancellable */ false);
1046       BranchInst::Create(AfterBB, AfterIP.getBlock());
1047 
1048       // Perform the actual outlining.
1049       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1050                                        /* AllowExtractorSinking */ true);
1051 
1052       Function *OutlinedFn = MergableCIs.front()->getCaller();
1053 
1054       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1055       // callbacks.
1056       SmallVector<Value *, 8> Args;
1057       for (auto *CI : MergableCIs) {
1058         Value *Callee =
1059             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1060         FunctionType *FT =
1061             cast<FunctionType>(Callee->getType()->getPointerElementType());
1062         Args.clear();
1063         Args.push_back(OutlinedFn->getArg(0));
1064         Args.push_back(OutlinedFn->getArg(1));
1065         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1066              U < E; ++U)
1067           Args.push_back(CI->getArgOperand(U));
1068 
1069         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1070         if (CI->getDebugLoc())
1071           NewCI->setDebugLoc(CI->getDebugLoc());
1072 
1073         // Forward parameter attributes from the callback to the callee.
1074         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1075              U < E; ++U)
1076           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1077             NewCI->addParamAttr(
1078                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1079 
1080         // Emit an explicit barrier to replace the implicit fork-join barrier.
1081         if (CI != MergableCIs.back()) {
1082           // TODO: Remove barrier if the merged parallel region includes the
1083           // 'nowait' clause.
1084           OMPInfoCache.OMPBuilder.createBarrier(
1085               InsertPointTy(NewCI->getParent(),
1086                             NewCI->getNextNode()->getIterator()),
1087               OMPD_parallel);
1088         }
1089 
1090         CI->eraseFromParent();
1091       }
1092 
1093       assert(OutlinedFn != OriginalFn && "Outlining failed");
1094       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1095       CGUpdater.reanalyzeFunction(*OriginalFn);
1096 
1097       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1098 
1099       return true;
1100     };
1101 
1102     // Helper function that identifies sequences of
1103     // __kmpc_fork_call uses in a basic block.
1104     auto DetectPRsCB = [&](Use &U, Function &F) {
1105       CallInst *CI = getCallIfRegularCall(U, &RFI);
1106       BB2PRMap[CI->getParent()].insert(CI);
1107 
1108       return false;
1109     };
1110 
1111     BB2PRMap.clear();
1112     RFI.foreachUse(SCC, DetectPRsCB);
1113     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1114     // Find mergable parallel regions within a basic block that are
1115     // safe to merge, that is any in-between instructions can safely
1116     // execute in parallel after merging.
1117     // TODO: support merging across basic-blocks.
1118     for (auto &It : BB2PRMap) {
1119       auto &CIs = It.getSecond();
1120       if (CIs.size() < 2)
1121         continue;
1122 
1123       BasicBlock *BB = It.getFirst();
1124       SmallVector<CallInst *, 4> MergableCIs;
1125 
1126       /// Returns true if the instruction is mergable, false otherwise.
1127       /// A terminator instruction is unmergable by definition since merging
1128       /// works within a BB. Instructions before the mergable region are
1129       /// mergable if they are not calls to OpenMP runtime functions that may
1130       /// set different execution parameters for subsequent parallel regions.
1131       /// Instructions in-between parallel regions are mergable if they are not
1132       /// calls to any non-intrinsic function since that may call a non-mergable
1133       /// OpenMP runtime function.
1134       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1135         // We do not merge across BBs, hence return false (unmergable) if the
1136         // instruction is a terminator.
1137         if (I.isTerminator())
1138           return false;
1139 
1140         if (!isa<CallInst>(&I))
1141           return true;
1142 
1143         CallInst *CI = cast<CallInst>(&I);
1144         if (IsBeforeMergableRegion) {
1145           Function *CalledFunction = CI->getCalledFunction();
1146           if (!CalledFunction)
1147             return false;
1148           // Return false (unmergable) if the call before the parallel
1149           // region is to a compiler-generated runtime function that sets an
1150           // explicit affinity (proc_bind) or number of threads (num_threads).
1151           // Those settings may be incompatible with following parallel regions.
1152           // TODO: ICV tracking to detect compatibility.
1153           for (const auto &RFI : UnmergableCallsInfo) {
1154             if (CalledFunction == RFI.Declaration)
1155               return false;
1156           }
1157         } else {
1158           // Return false (unmergable) if there is a call instruction
1159           // in-between parallel regions when it is not an intrinsic. It
1160           // may call an unmergable OpenMP runtime function in its callpath.
1161           // TODO: Keep track of possible OpenMP calls in the callpath.
1162           if (!isa<IntrinsicInst>(CI))
1163             return false;
1164         }
1165 
1166         return true;
1167       };
1168       // Find maximal number of parallel region CIs that are safe to merge.
1169       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1170         Instruction &I = *It;
1171         ++It;
1172 
1173         if (CIs.count(&I)) {
1174           MergableCIs.push_back(cast<CallInst>(&I));
1175           continue;
1176         }
1177 
1178         // Continue expanding if the instruction is mergable.
1179         if (IsMergable(I, MergableCIs.empty()))
1180           continue;
1181 
1182         // Forward the instruction iterator to skip the next parallel region
1183         // since there is an unmergable instruction which can affect it.
1184         for (; It != End; ++It) {
1185           Instruction &SkipI = *It;
1186           if (CIs.count(&SkipI)) {
1187             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1188                               << " due to " << I << "\n");
1189             ++It;
1190             break;
1191           }
1192         }
1193 
1194         // Store mergable regions found.
1195         if (MergableCIs.size() > 1) {
1196           MergableCIsVector.push_back(MergableCIs);
1197           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1198                             << " parallel regions in block " << BB->getName()
1199                             << " of function " << BB->getParent()->getName()
1200                             << "\n";);
1201         }
1202 
1203         MergableCIs.clear();
1204       }
1205 
1206       if (!MergableCIsVector.empty()) {
1207         Changed = true;
1208 
1209         for (auto &MergableCIs : MergableCIsVector)
1210           Merge(MergableCIs, BB);
1211         MergableCIsVector.clear();
1212       }
1213     }
1214 
1215     if (Changed) {
1216       /// Re-collect uses for fork calls, emitted barrier calls, and
1217       /// any emitted master/end_master calls.
1218       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1219       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1220       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1221       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1222     }
1223 
1224     return Changed;
1225   }
1226 
1227   /// Try to delete parallel regions if possible.
1228   bool deleteParallelRegions() {
1229     const unsigned CallbackCalleeOperand = 2;
1230 
1231     OMPInformationCache::RuntimeFunctionInfo &RFI =
1232         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1233 
1234     if (!RFI.Declaration)
1235       return false;
1236 
1237     bool Changed = false;
1238     auto DeleteCallCB = [&](Use &U, Function &) {
1239       CallInst *CI = getCallIfRegularCall(U);
1240       if (!CI)
1241         return false;
1242       auto *Fn = dyn_cast<Function>(
1243           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1244       if (!Fn)
1245         return false;
1246       if (!Fn->onlyReadsMemory())
1247         return false;
1248       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1249         return false;
1250 
1251       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1252                         << CI->getCaller()->getName() << "\n");
1253 
1254       auto Remark = [&](OptimizationRemark OR) {
1255         return OR << "Removing parallel region with no side-effects.";
1256       };
1257       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1258 
1259       CGUpdater.removeCallSite(*CI);
1260       CI->eraseFromParent();
1261       Changed = true;
1262       ++NumOpenMPParallelRegionsDeleted;
1263       return true;
1264     };
1265 
1266     RFI.foreachUse(SCC, DeleteCallCB);
1267 
1268     return Changed;
1269   }
1270 
1271   /// Try to eliminate runtime calls by reusing existing ones.
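  ///
  /// For illustration (sketch): two calls to omp_get_thread_limit() in the
  /// same function, with nothing in between that could change the result,
  /// are collapsed so that the second call simply reuses the value of the
  /// first.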
1272   bool deduplicateRuntimeCalls() {
1273     bool Changed = false;
1274 
1275     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1276         OMPRTL_omp_get_num_threads,
1277         OMPRTL_omp_in_parallel,
1278         OMPRTL_omp_get_cancellation,
1279         OMPRTL_omp_get_thread_limit,
1280         OMPRTL_omp_get_supported_active_levels,
1281         OMPRTL_omp_get_level,
1282         OMPRTL_omp_get_ancestor_thread_num,
1283         OMPRTL_omp_get_team_size,
1284         OMPRTL_omp_get_active_level,
1285         OMPRTL_omp_in_final,
1286         OMPRTL_omp_get_proc_bind,
1287         OMPRTL_omp_get_num_places,
1288         OMPRTL_omp_get_num_procs,
1289         OMPRTL_omp_get_place_num,
1290         OMPRTL_omp_get_partition_num_places,
1291         OMPRTL_omp_get_partition_place_nums};
1292 
1293     // Global-tid is handled separately.
1294     SmallSetVector<Value *, 16> GTIdArgs;
1295     collectGlobalThreadIdArguments(GTIdArgs);
1296     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1297                       << " global thread ID arguments\n");
1298 
1299     for (Function *F : SCC) {
1300       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1301         Changed |= deduplicateRuntimeCalls(
1302             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1303 
1304       // __kmpc_global_thread_num is special as we can replace it with an
1305       // argument in enough cases to make it worth trying.
1306       Value *GTIdArg = nullptr;
1307       for (Argument &Arg : F->args())
1308         if (GTIdArgs.count(&Arg)) {
1309           GTIdArg = &Arg;
1310           break;
1311         }
1312       Changed |= deduplicateRuntimeCalls(
1313           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1314     }
1315 
1316     return Changed;
1317   }
1318 
1319   /// Tries to hide the latency of runtime calls that involve host to
1320   /// device memory transfers by splitting them into their "issue" and "wait"
1321   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1322   /// moved downards as much as possible. The "issue" issues the memory transfer
1323   /// moved downwards as much as possible. The "issue" issues the memory transfer
1324   /// asynchronously, returning a handle. The "wait" waits on the returned
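  ///
  /// Conceptually (a sketch; the split is performed by splitTargetDataBeginRTC):
  ///   call @__tgt_target_data_begin_mapper(...)   ; synchronous
  /// becomes
  ///   %handle = ... issue the transfer ...        ; hoisted as early as possible
  ///   ...
  ///   ... wait on %handle ...                     ; sunk as late as possible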
1325   bool hideMemTransfersLatency() {
1326     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1327     bool Changed = false;
1328     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1329       auto *RTCall = getCallIfRegularCall(U, &RFI);
1330       if (!RTCall)
1331         return false;
1332 
1333       OffloadArray OffloadArrays[3];
1334       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1335         return false;
1336 
1337       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1338 
1339       // TODO: Check if can be moved upwards.
1340       bool WasSplit = false;
1341       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1342       if (WaitMovementPoint)
1343         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1344 
1345       Changed |= WasSplit;
1346       return WasSplit;
1347     };
1348     RFI.foreachUse(SCC, SplitMemTransfers);
1349 
1350     return Changed;
1351   }
1352 
1353   void analysisGlobalization() {
1354     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1355 
1356     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1357       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1358         auto Remark = [&](OptimizationRemarkMissed ORM) {
1359           return ORM
1360                  << "Found thread data sharing on the GPU. "
1361                  << "Expect degraded performance due to data globalization.";
1362         };
1363         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1364       }
1365 
1366       return false;
1367     };
1368 
1369     RFI.foreachUse(SCC, CheckGlobalization);
1370   }
1371 
1372   /// Maps the values stored in the offload arrays passed as arguments to
1373   /// \p RuntimeCall into the offload arrays in \p OAs.
1374   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1375                                 MutableArrayRef<OffloadArray> OAs) {
1376     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1377 
1378     // A runtime call that involves memory offloading looks something like:
1379     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1380     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1381     // ...)
1382     // So, the idea is to access the allocas that allocate space for these
1383     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1384     // Therefore:
1385     // i8** %offload_baseptrs.
1386     Value *BasePtrsArg =
1387         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1388     // i8** %offload_ptrs.
1389     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1390     // i8** %offload_sizes.
1391     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1392 
1393     // Get values stored in **offload_baseptrs.
1394     auto *V = getUnderlyingObject(BasePtrsArg);
1395     if (!isa<AllocaInst>(V))
1396       return false;
1397     auto *BasePtrsArray = cast<AllocaInst>(V);
1398     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1399       return false;
1400 
    // Get values stored in **offload_ptrs.
1402     V = getUnderlyingObject(PtrsArg);
1403     if (!isa<AllocaInst>(V))
1404       return false;
1405     auto *PtrsArray = cast<AllocaInst>(V);
1406     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1407       return false;
1408 
1409     // Get values stored in **offload_sizes.
1410     V = getUnderlyingObject(SizesArg);
1411     // If it's a [constant] global array don't analyze it.
1412     if (isa<GlobalValue>(V))
1413       return isa<Constant>(V);
1414     if (!isa<AllocaInst>(V))
1415       return false;
1416 
1417     auto *SizesArray = cast<AllocaInst>(V);
1418     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1419       return false;
1420 
1421     return true;
1422   }
1423 
1424   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1425   /// For now this is a way to test that the function getValuesInOffloadArrays
1426   /// is working properly.
1427   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1428   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1429     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1430 
1431     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1432     std::string ValuesStr;
1433     raw_string_ostream Printer(ValuesStr);
1434     std::string Separator = " --- ";
1435 
1436     for (auto *BP : OAs[0].StoredValues) {
1437       BP->print(Printer);
1438       Printer << Separator;
1439     }
1440     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1441     ValuesStr.clear();
1442 
1443     for (auto *P : OAs[1].StoredValues) {
1444       P->print(Printer);
1445       Printer << Separator;
1446     }
1447     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1448     ValuesStr.clear();
1449 
1450     for (auto *S : OAs[2].StoredValues) {
1451       S->print(Printer);
1452       Printer << Separator;
1453     }
1454     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1455   }
1456 
  /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
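  ///
  /// For example (illustrative), the "wait" can be sunk past side-effect-free
  /// instructions such as
  ///   %x = add i32 %a, %b
  /// but not past a load, store, or call that may read or write memory.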
1459   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1460     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1461     //  Make it traverse the CFG.
1462 
1463     Instruction *CurrentI = &RuntimeCall;
1464     bool IsWorthIt = false;
1465     while ((CurrentI = CurrentI->getNextNode())) {
1466 
1467       // TODO: Once we detect the regions to be offloaded we should use the
1468       //  alias analysis manager to check if CurrentI may modify one of
1469       //  the offloaded regions.
1470       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1471         if (IsWorthIt)
1472           return CurrentI;
1473 
1474         return nullptr;
1475       }
1476 
      // FIXME: For now, moving it over anything without side effects is
      //  considered worth it.
1479       IsWorthIt = true;
1480     }
1481 
1482     // Return end of BasicBlock.
1483     return RuntimeCall.getParent()->getTerminator();
1484   }
1485 
1486   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1487   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1488                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It stores information about the async transfer, allowing
    // us to wait on it later.
1492     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1493     auto *F = RuntimeCall.getCaller();
1494     Instruction *FirstInst = &(F->getEntryBlock().front());
1495     AllocaInst *Handle = new AllocaInst(
1496         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1497 
1498     // Add "issue" runtime call declaration:
1499     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1500     //   i8**, i8**, i64*, i64*)
1501     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1502         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1503 
    // Replace the RuntimeCall call site with its asynchronous version.
1505     SmallVector<Value *, 16> Args;
1506     for (auto &Arg : RuntimeCall.args())
1507       Args.push_back(Arg.get());
1508     Args.push_back(Handle);
1509 
1510     CallInst *IssueCallsite =
1511         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1512     RuntimeCall.eraseFromParent();
1513 
1514     // Add "wait" runtime call declaration:
1515     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1516     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1517         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1518 
1519     Value *WaitParams[2] = {
1520         IssueCallsite->getArgOperand(
1521             OffloadArray::DeviceIDArgNum), // device_id.
1522         Handle                             // handle to wait on.
1523     };
1524     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1525 
1526     return true;
1527   }
1528 
1529   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1530                                     bool GlobalOnly, bool &SingleChoice) {
1531     if (CurrentIdent == NextIdent)
1532       return CurrentIdent;
1533 
1534     // TODO: Figure out how to actually combine multiple debug locations. For
1535     //       now we just keep an existing one if there is a single choice.
1536     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1537       SingleChoice = !CurrentIdent;
1538       return NextIdent;
1539     }
1540     return nullptr;
1541   }
1542 
  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value, we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
1548   Value *
1549   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1550                                  Function &F, bool GlobalOnly) {
1551     bool SingleChoice = true;
1552     Value *Ident = nullptr;
1553     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1554       CallInst *CI = getCallIfRegularCall(U, &RFI);
1555       if (!CI || &F != &Caller)
1556         return false;
1557       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1558                                   /* GlobalOnly */ true, SingleChoice);
1559       return false;
1560     };
1561     RFI.foreachUse(SCC, CombineIdentStruct);
1562 
1563     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate but we work around it for now.
1566       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1567         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1568             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1570       // TODO: Use the debug locations of the calls instead.
1571       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1572       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1573     }
1574     return Ident;
1575   }
1576 
  /// Try to eliminate calls of \p RFI in \p F by reusing an existing one, or
  /// \p ReplVal if given.
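  ///
  /// Illustrative sketch: two calls in \p F such as
  ///   %a = call i32 @omp_get_thread_limit()
  ///   ...
  ///   %b = call i32 @omp_get_thread_limit()
  /// are collapsed so that all uses of %b are rewritten to %a (or \p ReplVal)
  /// and the redundant call is erased.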
1579   bool deduplicateRuntimeCalls(Function &F,
1580                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1581                                Value *ReplVal = nullptr) {
1582     auto *UV = RFI.getUseVector(F);
1583     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1584       return false;
1585 
    LLVM_DEBUG(dbgs() << TAG << "Deduplicate " << UV->size() << " uses of "
                      << RFI.Name
                      << (ReplVal ? " with an existing value\n" : "\n"));
1589 
1590     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1591                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1592            "Unexpected replacement value!");
1593 
1594     // TODO: Use dominance to find a good position instead.
1595     auto CanBeMoved = [this](CallBase &CB) {
1596       unsigned NumArgs = CB.getNumArgOperands();
1597       if (NumArgs == 0)
1598         return true;
1599       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1600         return false;
1601       for (unsigned u = 1; u < NumArgs; ++u)
1602         if (isa<Instruction>(CB.getArgOperand(u)))
1603           return false;
1604       return true;
1605     };
1606 
1607     if (!ReplVal) {
1608       for (Use *U : *UV)
1609         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1610           if (!CanBeMoved(*CI))
1611             continue;
1612 
1613           // If the function is a kernel, dedup will move
1614           // the runtime call right after the kernel init callsite. Otherwise,
1615           // it will move it to the beginning of the caller function.
1616           if (isKernel(F)) {
1617             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1618             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1619 
1620             if (KernelInitUV->empty())
1621               continue;
1622 
1623             assert(KernelInitUV->size() == 1 &&
1624                    "Expected a single __kmpc_target_init in kernel\n");
1625 
1626             CallInst *KernelInitCI =
1627                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1628             assert(KernelInitCI &&
1629                    "Expected a call to __kmpc_target_init in kernel\n");
1630 
1631             CI->moveAfter(KernelInitCI);
1632           } else
1633             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1634           ReplVal = CI;
1635           break;
1636         }
1637       if (!ReplVal)
1638         return false;
1639     }
1640 
1641     // If we use a call as a replacement value we need to make sure the ident is
1642     // valid at the new location. For now we just pick a global one, either
1643     // existing and used by one of the calls, or created from scratch.
1644     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1645       if (CI->getNumArgOperands() > 0 &&
1646           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1647         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1648                                                       /* GlobalOnly */ true);
1649         CI->setArgOperand(0, Ident);
1650       }
1651     }
1652 
1653     bool Changed = false;
1654     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1655       CallInst *CI = getCallIfRegularCall(U, &RFI);
1656       if (!CI || CI == ReplVal || &F != &Caller)
1657         return false;
1658       assert(CI->getCaller() == &F && "Unexpected call!");
1659 
1660       auto Remark = [&](OptimizationRemark OR) {
1661         return OR << "OpenMP runtime call "
1662                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1663       };
1664       if (CI->getDebugLoc())
1665         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1666       else
1667         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1668 
1669       CGUpdater.removeCallSite(*CI);
1670       CI->replaceAllUsesWith(ReplVal);
1671       CI->eraseFromParent();
1672       ++NumOpenMPRuntimeCallsDeduplicated;
1673       Changed = true;
1674       return true;
1675     };
1676     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1677 
1678     return Changed;
1679   }
1680 
1681   /// Collect arguments that represent the global thread id in \p GTIdArgs.
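  ///
  /// Illustrative sketch (@helper and @loc are placeholder names): given
  ///   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @helper(i32 %gtid)
  /// the first argument of the internal function @helper is collected,
  /// provided every call site of @helper passes a known global thread id.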
1682   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1683     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1684     //       initialization. We could define an AbstractAttribute instead and
1685     //       run the Attributor here once it can be run as an SCC pass.
1686 
1687     // Helper to check the argument \p ArgNo at all call sites of \p F for
1688     // a GTId.
1689     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1690       if (!F.hasLocalLinkage())
1691         return false;
1692       for (Use &U : F.uses()) {
1693         if (CallInst *CI = getCallIfRegularCall(U)) {
1694           Value *ArgOp = CI->getArgOperand(ArgNo);
1695           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1696               getCallIfRegularCall(
1697                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1698             continue;
1699         }
1700         return false;
1701       }
1702       return true;
1703     };
1704 
1705     // Helper to identify uses of a GTId as GTId arguments.
1706     auto AddUserArgs = [&](Value &GTId) {
1707       for (Use &U : GTId.uses())
1708         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1709           if (CI->isArgOperand(&U))
1710             if (Function *Callee = CI->getCalledFunction())
1711               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1712                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1713     };
1714 
1715     // The argument users of __kmpc_global_thread_num calls are GTIds.
1716     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1717         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1718 
1719     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1720       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1721         AddUserArgs(*CI);
1722       return false;
1723     });
1724 
1725     // Transitively search for more arguments by looking at the users of the
1726     // ones we know already. During the search the GTIdArgs vector is extended
1727     // so we cannot cache the size nor can we use a range based for.
1728     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1729       AddUserArgs(*GTIdArgs[u]);
1730   }
1731 
1732   /// Kernel (=GPU) optimizations and utility functions
1733   ///
1734   ///{{
1735 
1736   /// Check if \p F is a kernel, hence entry point for target offloading.
1737   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1738 
1739   /// Cache to remember the unique kernel for a function.
1740   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1741 
1742   /// Find the unique kernel that will execute \p F, if any.
1743   Kernel getUniqueKernelFor(Function &F);
1744 
1745   /// Find the unique kernel that will execute \p I, if any.
1746   Kernel getUniqueKernelFor(Instruction &I) {
1747     return getUniqueKernelFor(*I.getFunction());
1748   }
1749 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
1752   bool rewriteDeviceCodeStateMachine();
1753 
1754   ///
1755   ///}}
1756 
1757   /// Emit a remark generically
1758   ///
1759   /// This template function can be used to generically emit a remark. The
1760   /// RemarkKind should be one of the following:
1761   ///   - OptimizationRemark to indicate a successful optimization attempt
1762   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1763   ///   - OptimizationRemarkAnalysis to provide additional information about an
1764   ///     optimization attempt
1765   ///
1766   /// The remark is built using a callback function provided by the caller that
1767   /// takes a RemarkKind as input and returns a RemarkKind.
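  ///
  /// A typical use looks like (illustrative; the remark name is a
  /// placeholder):
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "Something was optimized.";
  ///   };
  ///   emitRemark<OptimizationRemark>(CI, "OMP123", Remark);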
1768   template <typename RemarkKind, typename RemarkCallBack>
1769   void emitRemark(Instruction *I, StringRef RemarkName,
1770                   RemarkCallBack &&RemarkCB) const {
1771     Function *F = I->getParent()->getParent();
1772     auto &ORE = OREGetter(F);
1773 
1774     if (RemarkName.startswith("OMP"))
1775       ORE.emit([&]() {
1776         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1777                << " [" << RemarkName << "]";
1778       });
1779     else
1780       ORE.emit(
1781           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1782   }
1783 
1784   /// Emit a remark on a function.
1785   template <typename RemarkKind, typename RemarkCallBack>
1786   void emitRemark(Function *F, StringRef RemarkName,
1787                   RemarkCallBack &&RemarkCB) const {
1788     auto &ORE = OREGetter(F);
1789 
1790     if (RemarkName.startswith("OMP"))
1791       ORE.emit([&]() {
1792         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1793                << " [" << RemarkName << "]";
1794       });
1795     else
1796       ORE.emit(
1797           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1798   }
1799 
1800   /// RAII struct to temporarily change an RTL function's linkage to external.
1801   /// This prevents it from being mistakenly removed by other optimizations.
1802   struct ExternalizationRAII {
1803     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1804                         RuntimeFunction RFKind)
1805         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1806       if (!Declaration)
1807         return;
1808 
1809       LinkageType = Declaration->getLinkage();
1810       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1811     }
1812 
1813     ~ExternalizationRAII() {
1814       if (!Declaration)
1815         return;
1816 
1817       Declaration->setLinkage(LinkageType);
1818     }
1819 
1820     Function *Declaration;
1821     GlobalValue::LinkageTypes LinkageType;
1822   };
1823 
1824   /// The underlying module.
1825   Module &M;
1826 
1827   /// The SCC we are operating on.
1828   SmallVectorImpl<Function *> &SCC;
1829 
1830   /// Callback to update the call graph, the first argument is a removed call,
1831   /// the second an optional replacement call.
1832   CallGraphUpdater &CGUpdater;
1833 
1834   /// Callback to get an OptimizationRemarkEmitter from a Function *
1835   OptimizationRemarkGetter OREGetter;
1836 
  /// OpenMP-specific information cache. Also used for Attributor runs.
1838   OMPInformationCache &OMPInfoCache;
1839 
1840   /// Attributor instance.
1841   Attributor &A;
1842 
1843   /// Helper function to run Attributor on SCC.
1844   bool runAttributor(bool IsModulePass) {
1845     if (SCC.empty())
1846       return false;
1847 
    // Temporarily make these functions have external linkage so the Attributor
    // doesn't remove them when we try to look them up later.
1850     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1851     ExternalizationRAII EndParallel(OMPInfoCache,
1852                                     OMPRTL___kmpc_kernel_end_parallel);
1853     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1854                                     OMPRTL___kmpc_barrier_simple_spmd);
1855 
1856     registerAAs(IsModulePass);
1857 
1858     ChangeStatus Changed = A.run();
1859 
1860     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1861                       << " functions, result: " << Changed << ".\n");
1862 
1863     return Changed == ChangeStatus::CHANGED;
1864   }
1865 
1866   void registerFoldRuntimeCall(RuntimeFunction RF);
1867 
1868   /// Populate the Attributor with abstract attribute opportunities in the
1869   /// function.
1870   void registerAAs(bool IsModulePass);
1871 };
1872 
1873 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1874   if (!OMPInfoCache.ModuleSlice.count(&F))
1875     return nullptr;
1876 
1877   // Use a scope to keep the lifetime of the CachedKernel short.
1878   {
1879     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1880     if (CachedKernel)
1881       return *CachedKernel;
1882 
1883     // TODO: We should use an AA to create an (optimistic and callback
1884     //       call-aware) call graph. For now we stick to simple patterns that
1885     //       are less powerful, basically the worst fixpoint.
1886     if (isKernel(F)) {
1887       CachedKernel = Kernel(&F);
1888       return *CachedKernel;
1889     }
1890 
1891     CachedKernel = nullptr;
1892     if (!F.hasLocalLinkage()) {
1893 
1894       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1895       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1896         return ORA << "Potentially unknown OpenMP target region caller.";
1897       };
1898       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1899 
1900       return nullptr;
1901     }
1902   }
1903 
1904   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1905     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1906       // Allow use in equality comparisons.
1907       if (Cmp->isEquality())
1908         return getUniqueKernelFor(*Cmp);
1909       return nullptr;
1910     }
1911     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1912       // Allow direct calls.
1913       if (CB->isCallee(&U))
1914         return getUniqueKernelFor(*CB);
1915 
1916       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1917           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1918       // Allow the use in __kmpc_parallel_51 calls.
1919       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1920         return getUniqueKernelFor(*CB);
1921       return nullptr;
1922     }
1923     // Disallow every other use.
1924     return nullptr;
1925   };
1926 
1927   // TODO: In the future we want to track more than just a unique kernel.
1928   SmallPtrSet<Kernel, 2> PotentialKernels;
1929   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1930     PotentialKernels.insert(GetUniqueKernelForUse(U));
1931   });
1932 
1933   Kernel K = nullptr;
1934   if (PotentialKernels.size() == 1)
1935     K = *PotentialKernels.begin();
1936 
1937   // Cache the result.
1938   UniqueKernelMap[&F] = K;
1939 
1940   return K;
1941 }
1942 
1943 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1944   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1945       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1946 
1947   bool Changed = false;
1948   if (!KernelParallelRFI)
1949     return Changed;
1950 
1951   // If we have disabled state machine changes, exit
1952   if (DisableOpenMPOptStateMachineRewrite)
1953     return Changed;
1954 
1955   for (Function *F : SCC) {
1956 
1957     // Check if the function is a use in a __kmpc_parallel_51 call at
1958     // all.
1959     bool UnknownUse = false;
1960     bool KernelParallelUse = false;
1961     unsigned NumDirectCalls = 0;
1962 
1963     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1964     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1965       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1966         if (CB->isCallee(&U)) {
1967           ++NumDirectCalls;
1968           return;
1969         }
1970 
1971       if (isa<ICmpInst>(U.getUser())) {
1972         ToBeReplacedStateMachineUses.push_back(&U);
1973         return;
1974       }
1975 
1976       // Find wrapper functions that represent parallel kernels.
1977       CallInst *CI =
1978           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1979       const unsigned int WrapperFunctionArgNo = 6;
1980       if (!KernelParallelUse && CI &&
1981           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1982         KernelParallelUse = true;
1983         ToBeReplacedStateMachineUses.push_back(&U);
1984         return;
1985       }
1986       UnknownUse = true;
1987     });
1988 
1989     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1990     // use.
1991     if (!KernelParallelUse)
1992       continue;
1993 
1994     // If this ever hits, we should investigate.
1995     // TODO: Checking the number of uses is not a necessary restriction and
1996     // should be lifted.
1997     if (UnknownUse || NumDirectCalls != 1 ||
1998         ToBeReplacedStateMachineUses.size() > 2) {
1999       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2000         return ORA << "Parallel region is used in "
2001                    << (UnknownUse ? "unknown" : "unexpected")
2002                    << " ways. Will not attempt to rewrite the state machine.";
2003       };
2004       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2005       continue;
2006     }
2007 
2008     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2009     // up if the function is not called from a unique kernel.
2010     Kernel K = getUniqueKernelFor(*F);
2011     if (!K) {
2012       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2013         return ORA << "Parallel region is not called from a unique kernel. "
2014                       "Will not attempt to rewrite the state machine.";
2015       };
2016       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2017       continue;
2018     }
2019 
2020     // We now know F is a parallel body function called only from the kernel K.
2021     // We also identified the state machine uses in which we replace the
2022     // function pointer by a new global symbol for identification purposes. This
2023     // ensures only direct calls to the function are left.
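    //
    // Illustrative sketch (argument lists abbreviated; @par_fn is a
    // placeholder name):
    //   call void @__kmpc_parallel_51(..., i8* bitcast (... @par_fn to i8*), ...)
    // becomes
    //   call void @__kmpc_parallel_51(..., i8* @par_fn.ID, ...)
    // while direct calls to @par_fn are left in place.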
2024 
2025     Module &M = *F->getParent();
2026     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2027 
2028     auto *ID = new GlobalVariable(
2029         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2030         UndefValue::get(Int8Ty), F->getName() + ".ID");
2031 
2032     for (Use *U : ToBeReplacedStateMachineUses)
2033       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
2034 
2035     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2036 
2037     Changed = true;
2038   }
2039 
2040   return Changed;
2041 }
2042 
2043 /// Abstract Attribute for tracking ICV values.
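///
/// For example (illustrative, assuming omp_set_num_threads and
/// omp_get_max_threads are the registered setter/getter of the nthreads ICV):
/// after
///   call void @omp_set_num_threads(i32 4)
/// a later
///   %n = call i32 @omp_get_max_threads()
/// may be assumed to return 4 as long as no intervening call can change the
/// ICV.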
2044 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2045   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2046   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2047 
2048   void initialize(Attributor &A) override {
2049     Function *F = getAnchorScope();
2050     if (!F || !A.isFunctionIPOAmendable(*F))
2051       indicatePessimisticFixpoint();
2052   }
2053 
2054   /// Returns true if value is assumed to be tracked.
2055   bool isAssumedTracked() const { return getAssumed(); }
2056 
2057   /// Returns true if value is known to be tracked.
2058   bool isKnownTracked() const { return getAssumed(); }
2059 
  /// Create an abstract attribute view for the position \p IRP.
2061   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2062 
2063   /// Return the value with which \p I can be replaced for specific \p ICV.
2064   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2065                                                 const Instruction *I,
2066                                                 Attributor &A) const {
2067     return None;
2068   }
2069 
  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return nullptr. If it is not clear yet, return
  /// None.
2073   virtual Optional<Value *>
2074   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2075 
  // Currently only nthreads is being tracked.
  // This array will only grow over time.
2078   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2079 
2080   /// See AbstractAttribute::getName()
2081   const std::string getName() const override { return "AAICVTracker"; }
2082 
2083   /// See AbstractAttribute::getIdAddr()
2084   const char *getIdAddr() const override { return &ID; }
2085 
2086   /// This function should return true if the type of the \p AA is AAICVTracker
2087   static bool classof(const AbstractAttribute *AA) {
2088     return (AA->getIdAddr() == &ID);
2089   }
2090 
2091   static const char ID;
2092 };
2093 
2094 struct AAICVTrackerFunction : public AAICVTracker {
2095   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2096       : AAICVTracker(IRP, A) {}
2097 
2098   // FIXME: come up with better string.
2099   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2100 
2101   // FIXME: come up with some stats.
2102   void trackStatistics() const override {}
2103 
2104   /// We don't manifest anything for this AA.
2105   ChangeStatus manifest(Attributor &A) override {
2106     return ChangeStatus::UNCHANGED;
2107   }
2108 
  // Map of ICVs to their values at specific program points.
2110   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2111                   InternalControlVar::ICV___last>
2112       ICVReplacementValuesMap;
2113 
2114   ChangeStatus updateImpl(Attributor &A) override {
2115     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2116 
2117     Function *F = getAnchorScope();
2118 
2119     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2120 
2121     for (InternalControlVar ICV : TrackableICVs) {
2122       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2123 
2124       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2125       auto TrackValues = [&](Use &U, Function &) {
2126         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2127         if (!CI)
2128           return false;
2129 
        // FIXME: Handle setters with more than one argument.
2131         /// Track new value.
2132         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2133           HasChanged = ChangeStatus::CHANGED;
2134 
2135         return false;
2136       };
2137 
2138       auto CallCheck = [&](Instruction &I) {
2139         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2140         if (ReplVal.hasValue() &&
2141             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2142           HasChanged = ChangeStatus::CHANGED;
2143 
2144         return true;
2145       };
2146 
2147       // Track all changes of an ICV.
2148       SetterRFI.foreachUse(TrackValues, F);
2149 
2150       bool UsedAssumedInformation = false;
2151       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2152                                 UsedAssumedInformation,
2153                                 /* CheckBBLivenessOnly */ true);
2154 
      /// TODO: Figure out a way to avoid adding an entry in
      /// ICVReplacementValuesMap.
2157       Instruction *Entry = &F->getEntryBlock().front();
2158       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2159         ValuesMap.insert(std::make_pair(Entry, nullptr));
2160     }
2161 
2162     return HasChanged;
2163   }
2164 
  /// Helper to check if \p I is a call and get the value for it if it is
  /// unique.
2167   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2168                                     InternalControlVar &ICV) const {
2169 
2170     const auto *CB = dyn_cast<CallBase>(I);
2171     if (!CB || CB->hasFnAttr("no_openmp") ||
2172         CB->hasFnAttr("no_openmp_routines"))
2173       return None;
2174 
2175     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2176     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2177     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2178     Function *CalledFunction = CB->getCalledFunction();
2179 
2180     // Indirect call, assume ICV changes.
2181     if (CalledFunction == nullptr)
2182       return nullptr;
2183     if (CalledFunction == GetterRFI.Declaration)
2184       return None;
2185     if (CalledFunction == SetterRFI.Declaration) {
2186       if (ICVReplacementValuesMap[ICV].count(I))
2187         return ICVReplacementValuesMap[ICV].lookup(I);
2188 
2189       return nullptr;
2190     }
2191 
2192     // Since we don't know, assume it changes the ICV.
2193     if (CalledFunction->isDeclaration())
2194       return nullptr;
2195 
2196     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2197         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2198 
2199     if (ICVTrackingAA.isAssumedTracked())
2200       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2201 
2202     // If we don't know, assume it changes.
2203     return nullptr;
2204   }
2205 
  // We don't track a unique value for a function, so return None.
2207   Optional<Value *>
2208   getUniqueReplacementValue(InternalControlVar ICV) const override {
2209     return None;
2210   }
2211 
2212   /// Return the value with which \p I can be replaced for specific \p ICV.
2213   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2214                                         const Instruction *I,
2215                                         Attributor &A) const override {
2216     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2217     if (ValuesMap.count(I))
2218       return ValuesMap.lookup(I);
2219 
2220     SmallVector<const Instruction *, 16> Worklist;
2221     SmallPtrSet<const Instruction *, 16> Visited;
2222     Worklist.push_back(I);
2223 
2224     Optional<Value *> ReplVal;
2225 
2226     while (!Worklist.empty()) {
2227       const Instruction *CurrInst = Worklist.pop_back_val();
2228       if (!Visited.insert(CurrInst).second)
2229         continue;
2230 
2231       const BasicBlock *CurrBB = CurrInst->getParent();
2232 
2233       // Go up and look for all potential setters/calls that might change the
2234       // ICV.
2235       while ((CurrInst = CurrInst->getPrevNode())) {
2236         if (ValuesMap.count(CurrInst)) {
2237           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2238           // Unknown value, track new.
2239           if (!ReplVal.hasValue()) {
2240             ReplVal = NewReplVal;
2241             break;
2242           }
2243 
          // If we found a new value, we can't know the ICV value anymore.
2245           if (NewReplVal.hasValue())
2246             if (ReplVal != NewReplVal)
2247               return nullptr;
2248 
2249           break;
2250         }
2251 
2252         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2253         if (!NewReplVal.hasValue())
2254           continue;
2255 
2256         // Unknown value, track new.
2257         if (!ReplVal.hasValue()) {
2258           ReplVal = NewReplVal;
2259           break;
2260         }
2261 
        // We found a different value, so we can't know the ICV value anymore.
2264         if (ReplVal != NewReplVal)
2265           return nullptr;
2266       }
2267 
2268       // If we are in the same BB and we have a value, we are done.
2269       if (CurrBB == I->getParent() && ReplVal.hasValue())
2270         return ReplVal;
2271 
2272       // Go through all predecessors and add terminators for analysis.
2273       for (const BasicBlock *Pred : predecessors(CurrBB))
2274         if (const Instruction *Terminator = Pred->getTerminator())
2275           Worklist.push_back(Terminator);
2276     }
2277 
2278     return ReplVal;
2279   }
2280 };
2281 
2282 struct AAICVTrackerFunctionReturned : AAICVTracker {
2283   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2284       : AAICVTracker(IRP, A) {}
2285 
2286   // FIXME: come up with better string.
2287   const std::string getAsStr() const override {
2288     return "ICVTrackerFunctionReturned";
2289   }
2290 
2291   // FIXME: come up with some stats.
2292   void trackStatistics() const override {}
2293 
2294   /// We don't manifest anything for this AA.
2295   ChangeStatus manifest(Attributor &A) override {
2296     return ChangeStatus::UNCHANGED;
2297   }
2298 
  // Map of ICVs to their values at specific program points.
2300   EnumeratedArray<Optional<Value *>, InternalControlVar,
2301                   InternalControlVar::ICV___last>
2302       ICVReplacementValuesMap;
2303 
  /// Return the unique value \p ICV is assumed to have at function returns.
2305   Optional<Value *>
2306   getUniqueReplacementValue(InternalControlVar ICV) const override {
2307     return ICVReplacementValuesMap[ICV];
2308   }
2309 
2310   ChangeStatus updateImpl(Attributor &A) override {
2311     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2312     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2313         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2314 
2315     if (!ICVTrackingAA.isAssumedTracked())
2316       return indicatePessimisticFixpoint();
2317 
2318     for (InternalControlVar ICV : TrackableICVs) {
2319       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2320       Optional<Value *> UniqueICVValue;
2321 
2322       auto CheckReturnInst = [&](Instruction &I) {
2323         Optional<Value *> NewReplVal =
2324             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2325 
2326         // If we found a second ICV value there is no unique returned value.
2327         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2328           return false;
2329 
2330         UniqueICVValue = NewReplVal;
2331 
2332         return true;
2333       };
2334 
2335       bool UsedAssumedInformation = false;
2336       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2337                                      UsedAssumedInformation,
2338                                      /* CheckBBLivenessOnly */ true))
2339         UniqueICVValue = nullptr;
2340 
2341       if (UniqueICVValue == ReplVal)
2342         continue;
2343 
2344       ReplVal = UniqueICVValue;
2345       Changed = ChangeStatus::CHANGED;
2346     }
2347 
2348     return Changed;
2349   }
2350 };
2351 
2352 struct AAICVTrackerCallSite : AAICVTracker {
2353   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2354       : AAICVTracker(IRP, A) {}
2355 
2356   void initialize(Attributor &A) override {
2357     Function *F = getAnchorScope();
2358     if (!F || !A.isFunctionIPOAmendable(*F))
2359       indicatePessimisticFixpoint();
2360 
2361     // We only initialize this AA for getters, so we need to know which ICV it
2362     // gets.
2363     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2364     for (InternalControlVar ICV : TrackableICVs) {
2365       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2366       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2367       if (Getter.Declaration == getAssociatedFunction()) {
2368         AssociatedICV = ICVInfo.Kind;
2369         return;
2370       }
2371     }
2372 
2373     /// Unknown ICV.
2374     indicatePessimisticFixpoint();
2375   }
2376 
2377   ChangeStatus manifest(Attributor &A) override {
2378     if (!ReplVal.hasValue() || !ReplVal.getValue())
2379       return ChangeStatus::UNCHANGED;
2380 
2381     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2382     A.deleteAfterManifest(*getCtxI());
2383 
2384     return ChangeStatus::CHANGED;
2385   }
2386 
2387   // FIXME: come up with better string.
2388   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2389 
2390   // FIXME: come up with some stats.
2391   void trackStatistics() const override {}
2392 
2393   InternalControlVar AssociatedICV;
2394   Optional<Value *> ReplVal;
2395 
2396   ChangeStatus updateImpl(Attributor &A) override {
2397     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2398         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2399 
2400     // We don't have any information, so we assume it changes the ICV.
2401     if (!ICVTrackingAA.isAssumedTracked())
2402       return indicatePessimisticFixpoint();
2403 
2404     Optional<Value *> NewReplVal =
2405         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2406 
2407     if (ReplVal == NewReplVal)
2408       return ChangeStatus::UNCHANGED;
2409 
2410     ReplVal = NewReplVal;
2411     return ChangeStatus::CHANGED;
2412   }
2413 
  // Return the value with which the associated value can be replaced for the
  // specific \p ICV.
2416   Optional<Value *>
2417   getUniqueReplacementValue(InternalControlVar ICV) const override {
2418     return ReplVal;
2419   }
2420 };
2421 
2422 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2423   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2424       : AAICVTracker(IRP, A) {}
2425 
2426   // FIXME: come up with better string.
2427   const std::string getAsStr() const override {
2428     return "ICVTrackerCallSiteReturned";
2429   }
2430 
2431   // FIXME: come up with some stats.
2432   void trackStatistics() const override {}
2433 
2434   /// We don't manifest anything for this AA.
2435   ChangeStatus manifest(Attributor &A) override {
2436     return ChangeStatus::UNCHANGED;
2437   }
2438 
  // Map of ICVs to their values at specific program points.
2440   EnumeratedArray<Optional<Value *>, InternalControlVar,
2441                   InternalControlVar::ICV___last>
2442       ICVReplacementValuesMap;
2443 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2446   Optional<Value *>
2447   getUniqueReplacementValue(InternalControlVar ICV) const override {
2448     return ICVReplacementValuesMap[ICV];
2449   }
2450 
2451   ChangeStatus updateImpl(Attributor &A) override {
2452     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2453     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2454         *this, IRPosition::returned(*getAssociatedFunction()),
2455         DepClassTy::REQUIRED);
2456 
2457     // We don't have any information, so we assume it changes the ICV.
2458     if (!ICVTrackingAA.isAssumedTracked())
2459       return indicatePessimisticFixpoint();
2460 
2461     for (InternalControlVar ICV : TrackableICVs) {
2462       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2463       Optional<Value *> NewReplVal =
2464           ICVTrackingAA.getUniqueReplacementValue(ICV);
2465 
2466       if (ReplVal == NewReplVal)
2467         continue;
2468 
2469       ReplVal = NewReplVal;
2470       Changed = ChangeStatus::CHANGED;
2471     }
2472     return Changed;
2473   }
2474 };
2475 
2476 struct AAExecutionDomainFunction : public AAExecutionDomain {
2477   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2478       : AAExecutionDomain(IRP, A) {}
2479 
2480   const std::string getAsStr() const override {
2481     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2482            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2483   }
2484 
2485   /// See AbstractAttribute::trackStatistics().
2486   void trackStatistics() const override {}
2487 
2488   void initialize(Attributor &A) override {
2489     Function *F = getAnchorScope();
2490     for (const auto &BB : *F)
2491       SingleThreadedBBs.insert(&BB);
2492     NumBBs = SingleThreadedBBs.size();
2493   }
2494 
2495   ChangeStatus manifest(Attributor &A) override {
2496     LLVM_DEBUG({
2497       for (const BasicBlock *BB : SingleThreadedBBs)
2498         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2499                << BB->getName() << " is executed by a single thread.\n";
2500     });
2501     return ChangeStatus::UNCHANGED;
2502   }
2503 
2504   ChangeStatus updateImpl(Attributor &A) override;
2505 
2506   /// Check if an instruction is executed by a single thread.
2507   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2508     return isExecutedByInitialThreadOnly(*I.getParent());
2509   }
2510 
2511   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2512     return isValidState() && SingleThreadedBBs.contains(&BB);
2513   }
2514 
2515   /// Set of basic blocks that are executed by a single thread.
2516   DenseSet<const BasicBlock *> SingleThreadedBBs;
2517 
2518   /// Total number of basic blocks in this function.
2519   long unsigned NumBBs;
2520 };
2521 
2522 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2523   Function *F = getAnchorScope();
2524   ReversePostOrderTraversal<Function *> RPOT(F);
2525   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2526 
2527   bool AllCallSitesKnown;
2528   auto PredForCallSite = [&](AbstractCallSite ACS) {
2529     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2530         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2531         DepClassTy::REQUIRED);
2532     return ACS.isDirectCall() &&
2533            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2534                *ACS.getInstruction());
2535   };
2536 
2537   if (!A.checkForAllCallSites(PredForCallSite, *this,
2538                               /* RequiresAllCallSites */ true,
2539                               AllCallSitesKnown))
2540     SingleThreadedBBs.erase(&F->getEntryBlock());
2541 
2542   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2543   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2544 
  // Check if the edge into the successor block compares the __kmpc_target_init
  // result with -1. If we are in non-SPMD mode, that signals that only the
  // main thread will execute the edge.
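  //
  // Illustrative pattern (argument lists abbreviated):
  //   %tid = call i32 @__kmpc_target_init(%ident, i1 false /*IsSPMD*/, ...)
  //   %cmp = icmp eq i32 %tid, -1
  //   br i1 %cmp, label %main.thread.only, label %workers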
2548   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2549     if (!Edge || !Edge->isConditional())
2550       return false;
2551     if (Edge->getSuccessor(0) != SuccessorBB)
2552       return false;
2553 
2554     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2555     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2556       return false;
2557 
2558     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2559     if (!C)
2560       return false;
2561 
2562     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2563     if (C->isAllOnesValue()) {
2564       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2565       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2566       if (!CB)
2567         return false;
2568       const int InitIsSPMDArgNo = 1;
2569       auto *IsSPMDModeCI =
2570           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2571       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2572     }
2573 
2574     return false;
2575   };
2576 
2577   // Merge all the predecessor states into the current basic block. A basic
2578   // block is executed by a single thread if all of its predecessors are.
2579   auto MergePredecessorStates = [&](BasicBlock *BB) {
2580     if (pred_begin(BB) == pred_end(BB))
2581       return SingleThreadedBBs.contains(BB);
2582 
2583     bool IsInitialThread = true;
2584     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2585          PredBB != PredEndBB; ++PredBB) {
2586       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2587                                BB))
2588         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2589     }
2590 
2591     return IsInitialThread;
2592   };
2593 
2594   for (auto *BB : RPOT) {
2595     if (!MergePredecessorStates(BB))
2596       SingleThreadedBBs.erase(BB);
2597   }
2598 
2599   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2600              ? ChangeStatus::UNCHANGED
2601              : ChangeStatus::CHANGED;
2602 }
2603 
2604 /// Try to replace memory allocation calls called by a single thread with a
2605 /// static buffer of shared memory.
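///
/// Illustrative sketch (argument lists abbreviated; @x.shared is a placeholder
/// name):
///   %p = call i8* @__kmpc_alloc_shared(i64 4)
///   ...
///   call void @__kmpc_free_shared(i8* %p)
/// becomes
///   @x.shared = internal addrspace(3) global [4 x i8] undef, align 32
/// with uses of %p rewritten to a pointer cast of @x.shared and the paired
/// free call removed.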
2606 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2607   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2608   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2609 
2610   /// Create an abstract attribute view for the position \p IRP.
2611   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2612                                            Attributor &A);
2613 
2614   /// Returns true if HeapToShared conversion is assumed to be possible.
2615   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2616 
2617   /// Returns true if HeapToShared conversion is assumed and the CB is a
2618   /// callsite to a free operation to be removed.
2619   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2620 
2621   /// See AbstractAttribute::getName().
2622   const std::string getName() const override { return "AAHeapToShared"; }
2623 
2624   /// See AbstractAttribute::getIdAddr().
2625   const char *getIdAddr() const override { return &ID; }
2626 
2627   /// This function should return true if the type of the \p AA is
2628   /// AAHeapToShared.
2629   static bool classof(const AbstractAttribute *AA) {
2630     return (AA->getIdAddr() == &ID);
2631   }
2632 
2633   /// Unique ID (due to the unique address)
2634   static const char ID;
2635 };
2636 
2637 struct AAHeapToSharedFunction : public AAHeapToShared {
2638   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2639       : AAHeapToShared(IRP, A) {}
2640 
2641   const std::string getAsStr() const override {
2642     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2643            " malloc calls eligible.";
2644   }
2645 
2646   /// See AbstractAttribute::trackStatistics().
2647   void trackStatistics() const override {}
2648 
  /// This function finds free calls that will be removed by the
  /// HeapToShared transformation.
2651   void findPotentialRemovedFreeCalls(Attributor &A) {
2652     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2653     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2654 
2655     PotentialRemovedFreeCalls.clear();
2656     // Update free call users of found malloc calls.
2657     for (CallBase *CB : MallocCalls) {
2658       SmallVector<CallBase *, 4> FreeCalls;
2659       for (auto *U : CB->users()) {
2660         CallBase *C = dyn_cast<CallBase>(U);
2661         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2662           FreeCalls.push_back(C);
2663       }
2664 
2665       if (FreeCalls.size() != 1)
2666         continue;
2667 
2668       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2669     }
2670   }
2671 
2672   void initialize(Attributor &A) override {
2673     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2674     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2675 
2676     for (User *U : RFI.Declaration->users())
2677       if (CallBase *CB = dyn_cast<CallBase>(U))
2678         MallocCalls.insert(CB);
2679 
2680     findPotentialRemovedFreeCalls(A);
2681   }
2682 
2683   bool isAssumedHeapToShared(CallBase &CB) const override {
2684     return isValidState() && MallocCalls.count(&CB);
2685   }
2686 
2687   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2688     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2689   }
2690 
2691   ChangeStatus manifest(Attributor &A) override {
2692     if (MallocCalls.empty())
2693       return ChangeStatus::UNCHANGED;
2694 
2695     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2696     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2697 
2698     Function *F = getAnchorScope();
2699     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2700                                             DepClassTy::OPTIONAL);
2701 
2702     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2703     for (CallBase *CB : MallocCalls) {
2704       // Skip replacing this if HeapToStack has already claimed it.
2705       if (HS && HS->isAssumedHeapToStack(*CB))
2706         continue;
2707 
2708       // Find the unique free call to remove it.
2709       SmallVector<CallBase *, 4> FreeCalls;
2710       for (auto *U : CB->users()) {
2711         CallBase *C = dyn_cast<CallBase>(U);
2712         if (C && C->getCalledFunction() == FreeCall.Declaration)
2713           FreeCalls.push_back(C);
2714       }
2715       if (FreeCalls.size() != 1)
2716         continue;
2717 
2718       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2719 
2720       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2721                         << " with " << AllocSize->getZExtValue()
2722                         << " bytes of shared memory\n");
2723 
2724       // Create a new shared memory buffer of the same size as the allocation
2725       // and replace all the uses of the original allocation with it.
2726       Module *M = CB->getModule();
2727       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2728       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2729       auto *SharedMem = new GlobalVariable(
2730           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2731           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2732           GlobalValue::NotThreadLocal,
2733           static_cast<unsigned>(AddressSpace::Shared));
2734       auto *NewBuffer =
2735           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2736 
2737       auto Remark = [&](OptimizationRemark OR) {
2738         return OR << "Replaced globalized variable with "
2739                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2740                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2741                   << "of shared memory.";
2742       };
2743       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2744 
2745       SharedMem->setAlignment(MaybeAlign(32));
2746 
2747       A.changeValueAfterManifest(*CB, *NewBuffer);
2748       A.deleteAfterManifest(*CB);
2749       A.deleteAfterManifest(*FreeCalls.front());
2750 
2751       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2752       Changed = ChangeStatus::CHANGED;
2753     }
2754 
2755     return Changed;
2756   }
2757 
2758   ChangeStatus updateImpl(Attributor &A) override {
2759     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2760     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2761     Function *F = getAnchorScope();
2762 
2763     auto NumMallocCalls = MallocCalls.size();
2764 
    // Only consider malloc calls executed only by the initial thread and with a
    // constant allocation size.
2766     for (User *U : RFI.Declaration->users()) {
2767       const auto &ED = A.getAAFor<AAExecutionDomain>(
2768           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2769       if (CallBase *CB = dyn_cast<CallBase>(U))
2770         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2771             !ED.isExecutedByInitialThreadOnly(*CB))
2772           MallocCalls.erase(CB);
2773     }
2774 
2775     findPotentialRemovedFreeCalls(A);
2776 
2777     if (NumMallocCalls != MallocCalls.size())
2778       return ChangeStatus::CHANGED;
2779 
2780     return ChangeStatus::UNCHANGED;
2781   }
2782 
2783   /// Collection of all malloc calls in a function.
2784   SmallPtrSet<CallBase *, 4> MallocCalls;
2785   /// Collection of potentially removed free calls in a function.
2786   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2787 };
2788 
2789 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2790   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2791   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2792 
2793   /// Statistics are tracked as part of manifest for now.
2794   void trackStatistics() const override {}
2795 
2796   /// See AbstractAttribute::getAsStr()
2797   const std::string getAsStr() const override {
2798     if (!isValidState())
2799       return "<invalid>";
2800     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2801                                                             : "generic") +
2802            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2803                                                                : "") +
2804            std::string(" #PRs: ") +
2805            std::to_string(ReachedKnownParallelRegions.size()) +
2806            ", #Unknown PRs: " +
2807            std::to_string(ReachedUnknownParallelRegions.size());
2808   }
2809 
  /// Create an abstract attribute view for the position \p IRP.
2811   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2812 
2813   /// See AbstractAttribute::getName()
2814   const std::string getName() const override { return "AAKernelInfo"; }
2815 
2816   /// See AbstractAttribute::getIdAddr()
2817   const char *getIdAddr() const override { return &ID; }
2818 
2819   /// This function should return true if the type of the \p AA is AAKernelInfo
2820   static bool classof(const AbstractAttribute *AA) {
2821     return (AA->getIdAddr() == &ID);
2822   }
2823 
2824   static const char ID;
2825 };
2826 
2827 /// The function kernel info abstract attribute, basically, what can we say
/// about a function with regard to the KernelInfoState.
2829 struct AAKernelInfoFunction : AAKernelInfo {
2830   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2831       : AAKernelInfo(IRP, A) {}
2832 
2833   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2834 
2835   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2836     return GuardedInstructions;
2837   }
2838 
2839   /// See AbstractAttribute::initialize(...).
2840   void initialize(Attributor &A) override {
    // This is a high-level transform that might change the constant arguments
    // of the init and deinit calls. We need to tell the Attributor about this
    // to avoid other parts using the current constant value for simplification.
2844     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2845 
2846     Function *Fn = getAnchorScope();
2847     if (!OMPInfoCache.Kernels.count(Fn))
2848       return;
2849 
2850     // Add itself to the reaching kernel and set IsKernelEntry.
2851     ReachingKernelEntries.insert(Fn);
2852     IsKernelEntry = true;
2853 
2854     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2855         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2856     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2857         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2858 
2859     // For kernels we perform more initialization work, first we find the init
2860     // and deinit calls.
2861     auto StoreCallBase = [](Use &U,
2862                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2863                             CallBase *&Storage) {
2864       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2865       assert(CB &&
2866              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2867       assert(!Storage &&
2868              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2869       Storage = CB;
2870       return false;
2871     };
2872     InitRFI.foreachUse(
2873         [&](Use &U, Function &) {
2874           StoreCallBase(U, InitRFI, KernelInitCB);
2875           return false;
2876         },
2877         Fn);
2878     DeinitRFI.foreachUse(
2879         [&](Use &U, Function &) {
2880           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2881           return false;
2882         },
2883         Fn);
2884 
2885     // Ignore kernels without initializers such as global constructors.
2886     if (!KernelInitCB || !KernelDeinitCB) {
2887       indicateOptimisticFixpoint();
2888       return;
2889     }
2890 
2891     // For kernels we might need to initialize/finalize the IsSPMD state and
2892     // we need to register a simplification callback so that the Attributor
2893     // knows the constant arguments to __kmpc_target_init and
2894     // __kmpc_target_deinit might actually change.
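    //
    // As a sketch (argument positions per the *ArgNo constants below; the
    // exact runtime signature may differ between versions), the init call
    // looks roughly like
    //   %tid = call i32 @__kmpc_target_init(%struct.ident_t* @ident,
    //              i1 %IsSPMD, i1 %UseGenericStateMachine,
    //              i1 %RequiresFullRuntime)
    // and the callbacks below expose our assumed internal state as the
    // simplified value of these boolean arguments.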
2895 
2896     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2897         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2898             bool &UsedAssumedInformation) -> Optional<Value *> {
2899       // IRP represents the "use generic state machine" argument of an
2900       // __kmpc_target_init call. We will answer this one with the internal
2901       // state. As long as we are not in an invalid state, we will create a
      // custom state machine so the value should be an `i1 false`. If we are
2903       // in an invalid state, we won't change the value that is in the IR.
2904       if (!isValidState())
2905         return nullptr;
2906       // If we have disabled state machine rewrites, don't make a custom one.
2907       if (DisableOpenMPOptStateMachineRewrite)
2908         return nullptr;
2909       if (AA)
2910         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2911       UsedAssumedInformation = !isAtFixpoint();
2912       auto *FalseVal =
2913           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2914       return FalseVal;
2915     };
2916 
2917     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2918         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2919             bool &UsedAssumedInformation) -> Optional<Value *> {
      // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
      // __kmpc_target_deinit call. We will answer this one with the internal
      // state of the SPMDCompatibilityTracker.
2924       if (!SPMDCompatibilityTracker.isValidState())
2925         return nullptr;
2926       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2927         if (AA)
2928           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2929         UsedAssumedInformation = true;
2930       } else {
2931         UsedAssumedInformation = false;
2932       }
2933       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2934                                        SPMDCompatibilityTracker.isAssumed());
2935       return Val;
2936     };
2937 
2938     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2939         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2940             bool &UsedAssumedInformation) -> Optional<Value *> {
2941       // IRP represents the "RequiresFullRuntime" argument of an
2942       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2943       // one with the internal state of the SPMDCompatibilityTracker, so if
2944       // generic then true, if SPMD then false.
2945       if (!SPMDCompatibilityTracker.isValidState())
2946         return nullptr;
2947       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2948         if (AA)
2949           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2950         UsedAssumedInformation = true;
2951       } else {
2952         UsedAssumedInformation = false;
2953       }
2954       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2955                                        !SPMDCompatibilityTracker.isAssumed());
2956       return Val;
2957     };
2958 
2959     constexpr const int InitIsSPMDArgNo = 1;
2960     constexpr const int DeinitIsSPMDArgNo = 1;
2961     constexpr const int InitUseStateMachineArgNo = 2;
2962     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2963     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2964     A.registerSimplificationCallback(
2965         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2966         StateMachineSimplifyCB);
2967     A.registerSimplificationCallback(
2968         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2969         IsSPMDModeSimplifyCB);
2970     A.registerSimplificationCallback(
2971         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2972         IsSPMDModeSimplifyCB);
2973     A.registerSimplificationCallback(
2974         IRPosition::callsite_argument(*KernelInitCB,
2975                                       InitRequiresFullRuntimeArgNo),
2976         IsGenericModeSimplifyCB);
2977     A.registerSimplificationCallback(
2978         IRPosition::callsite_argument(*KernelDeinitCB,
2979                                       DeinitRequiresFullRuntimeArgNo),
2980         IsGenericModeSimplifyCB);
2981 
2982     // Check if we know we are in SPMD-mode already.
2983     ConstantInt *IsSPMDArg =
2984         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2985     if (IsSPMDArg && !IsSPMDArg->isZero())
2986       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
2987     // This is a generic region but SPMDization is disabled so stop tracking.
2988     else if (DisableOpenMPOptSPMDization)
2989       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
2990   }
2991 
2992   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
2993   /// finished now.
2994   ChangeStatus manifest(Attributor &A) override {
2995     // If we are not looking at a kernel with __kmpc_target_init and
2996     // __kmpc_target_deinit call we cannot actually manifest the information.
2997     if (!KernelInitCB || !KernelDeinitCB)
2998       return ChangeStatus::UNCHANGED;
2999 
3000     // Known SPMD-mode kernels need no manifest changes.
3001     if (SPMDCompatibilityTracker.isKnown())
3002       return ChangeStatus::UNCHANGED;
3003 
    // If we can, we change the execution mode to SPMD-mode; otherwise we build
    // a custom state machine.
3006     if (!changeToSPMDMode(A))
3007       buildCustomStateMachine(A);
3008 
3009     return ChangeStatus::CHANGED;
3010   }
3011 
3012   bool changeToSPMDMode(Attributor &A) {
3013     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3014 
3015     if (!SPMDCompatibilityTracker.isAssumed()) {
3016       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3017         if (!NonCompatibleI)
3018           continue;
3019 
3020         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3021         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3022           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3023             continue;
3024 
3025         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3026           ORA << "Value has potential side effects preventing SPMD-mode "
3027                  "execution";
3028           if (isa<CallBase>(NonCompatibleI)) {
3029             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3030                    "the called function to override";
3031           }
3032           return ORA << ".";
3033         };
3034         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3035                                                  Remark);
3036 
3037         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3038                           << *NonCompatibleI << "\n");
3039       }
3040 
3041       return false;
3042     }
3043 
3044     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3045                                    Instruction *RegionEndI) {
3046       LoopInfo *LI = nullptr;
3047       DominatorTree *DT = nullptr;
3048       MemorySSAUpdater *MSU = nullptr;
3049       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3050 
3051       BasicBlock *ParentBB = RegionStartI->getParent();
3052       Function *Fn = ParentBB->getParent();
3053       Module &M = *Fn->getParent();
3054 
3055       // Create all the blocks and logic.
3056       // ParentBB:
3057       //    goto RegionCheckTidBB
3058       // RegionCheckTidBB:
3059       //    Tid = __kmpc_hardware_thread_id()
3060       //    if (Tid != 0)
3061       //        goto RegionBarrierBB
3062       // RegionStartBB:
3063       //    <execute instructions guarded>
3064       //    goto RegionEndBB
3065       // RegionEndBB:
3066       //    <store escaping values to shared mem>
3067       //    goto RegionBarrierBB
3068       //  RegionBarrierBB:
3069       //    __kmpc_simple_barrier_spmd()
      //    // second barrier is omitted if there are no escaping values.
3071       //    <load escaping values from shared mem>
3072       //    __kmpc_simple_barrier_spmd()
3073       //    goto RegionExitBB
3074       // RegionExitBB:
3075       //    <execute rest of instructions>
3076 
3077       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3078                                            DT, LI, MSU, "region.guarded.end");
3079       BasicBlock *RegionBarrierBB =
3080           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3081                      MSU, "region.barrier");
3082       BasicBlock *RegionExitBB =
3083           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3084                      DT, LI, MSU, "region.exit");
3085       BasicBlock *RegionStartBB =
3086           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3087 
3088       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3089              "Expected a different CFG");
3090 
3091       BasicBlock *RegionCheckTidBB = SplitBlock(
3092           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3093 
3094       // Register basic blocks with the Attributor.
3095       A.registerManifestAddedBasicBlock(*RegionEndBB);
3096       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3097       A.registerManifestAddedBasicBlock(*RegionExitBB);
3098       A.registerManifestAddedBasicBlock(*RegionStartBB);
3099       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3100 
3101       bool HasBroadcastValues = false;
3102       // Find escaping outputs from the guarded region to outside users and
3103       // broadcast their values to them.
3104       for (Instruction &I : *RegionStartBB) {
3105         SmallPtrSet<Instruction *, 4> OutsideUsers;
3106         for (User *Usr : I.users()) {
3107           Instruction &UsrI = *cast<Instruction>(Usr);
3108           if (UsrI.getParent() != RegionStartBB)
3109             OutsideUsers.insert(&UsrI);
3110         }
3111 
3112         if (OutsideUsers.empty())
3113           continue;
3114 
3115         HasBroadcastValues = true;
3116 
3117         // Emit a global variable in shared memory to store the broadcasted
3118         // value.
3119         auto *SharedMem = new GlobalVariable(
3120             M, I.getType(), /* IsConstant */ false,
3121             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3122             I.getName() + ".guarded.output.alloc", nullptr,
3123             GlobalValue::NotThreadLocal,
3124             static_cast<unsigned>(AddressSpace::Shared));
3125 
3126         // Emit a store instruction to update the value.
3127         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3128 
3129         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3130                                        I.getName() + ".guarded.output.load",
3131                                        RegionBarrierBB->getTerminator());
3132 
        // Replace uses of the output value outside the region with the load.
3134         for (Instruction *UsrI : OutsideUsers) {
3135           assert(UsrI->getParent() == RegionExitBB &&
3136                  "Expected escaping users in exit region");
3137           UsrI->replaceUsesOfWith(&I, LoadI);
3138         }
3139       }
3140 
3141       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3142 
3143       // Go to tid check BB in ParentBB.
3144       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3145       ParentBB->getTerminator()->eraseFromParent();
3146       OpenMPIRBuilder::LocationDescription Loc(
3147           InsertPointTy(ParentBB, ParentBB->end()), DL);
3148       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3149       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3150       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3151       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3152 
3153       // Add check for Tid in RegionCheckTidBB
3154       RegionCheckTidBB->getTerminator()->eraseFromParent();
3155       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3156           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3157       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3158       FunctionCallee HardwareTidFn =
3159           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3160               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3161       Value *Tid =
3162           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3163       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3164       OMPInfoCache.OMPBuilder.Builder
3165           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3166           ->setDebugLoc(DL);
3167 
      // First barrier for synchronization; it ensures the main thread has
      // updated the values.
3170       FunctionCallee BarrierFn =
3171           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3172               M, OMPRTL___kmpc_barrier_simple_spmd);
3173       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3174           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3175       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3176           ->setDebugLoc(DL);
3177 
3178       // Second barrier ensures workers have read broadcast values.
3179       if (HasBroadcastValues)
3180         CallInst::Create(BarrierFn, {Ident, Tid}, "",
3181                          RegionBarrierBB->getTerminator())
3182             ->setDebugLoc(DL);
3183     };
3184 
3185     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3186 
3187     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3188       BasicBlock *BB = GuardedI->getParent();
3189       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3190           IRPosition::function(*GuardedI->getFunction()), nullptr,
3191           DepClassTy::NONE);
3192       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3193       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3194       // Continue if instruction is already guarded.
3195       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3196         continue;
3197 
3198       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3199       for (Instruction &I : *BB) {
        // If instruction I needs to be guarded, update the guarded region
        // bounds.
3202         if (SPMDCompatibilityTracker.contains(&I)) {
3203           CalleeAAFunction.getGuardedInstructions().insert(&I);
3204           if (GuardedRegionStart)
3205             GuardedRegionEnd = &I;
3206           else
3207             GuardedRegionStart = GuardedRegionEnd = &I;
3208 
3209           continue;
3210         }
3211 
        // Instruction I does not need guarding; store any region found so far
        // and reset the bounds.
3214         if (GuardedRegionStart) {
3215           GuardedRegions.push_back(
3216               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3217           GuardedRegionStart = nullptr;
3218           GuardedRegionEnd = nullptr;
3219         }
3220       }
3221     }
3222 
3223     for (auto &GR : GuardedRegions)
3224       CreateGuardedRegion(GR.first, GR.second);
3225 
3226     // Adjust the global exec mode flag that tells the runtime what mode this
3227     // kernel is executed in.
3228     Function *Kernel = getAnchorScope();
3229     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3230         (Kernel->getName() + "_exec_mode").str());
3231     assert(ExecMode && "Kernel without exec mode?");
3232     assert(ExecMode->getInitializer() &&
3233            ExecMode->getInitializer()->isOneValue() &&
3234            "Initially non-SPMD kernel has SPMD exec mode!");
3235 
3236     // Set the global exec mode flag to indicate SPMD-Generic mode.
3237     constexpr int SPMDGeneric = 2;
3238     if (!ExecMode->getInitializer()->isZeroValue())
3239       ExecMode->setInitializer(
3240           ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
3241 
3242     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
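    // As a sketch (argument positions per the constants below), the rewritten
    // calls look roughly like
    //   call i32 @__kmpc_target_init(..., i1 true, i1 false, i1 false)
    //   call void @__kmpc_target_deinit(..., i1 true, i1 false)
    // i.e., SPMD-mode enabled, no generic state machine, no full runtime.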
3243     const int InitIsSPMDArgNo = 1;
3244     const int DeinitIsSPMDArgNo = 1;
3245     const int InitUseStateMachineArgNo = 2;
3246     const int InitRequiresFullRuntimeArgNo = 3;
3247     const int DeinitRequiresFullRuntimeArgNo = 2;
3248 
3249     auto &Ctx = getAnchorValue().getContext();
3250     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3251                              *ConstantInt::getBool(Ctx, 1));
3252     A.changeUseAfterManifest(
3253         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3254         *ConstantInt::getBool(Ctx, 0));
3255     A.changeUseAfterManifest(
3256         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3257         *ConstantInt::getBool(Ctx, 1));
3258     A.changeUseAfterManifest(
3259         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3260         *ConstantInt::getBool(Ctx, 0));
3261     A.changeUseAfterManifest(
3262         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3263         *ConstantInt::getBool(Ctx, 0));
3264 
3265     ++NumOpenMPTargetRegionKernelsSPMD;
3266 
3267     auto Remark = [&](OptimizationRemark OR) {
3268       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3269     };
3270     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3271     return true;
3272   };
3273 
3274   ChangeStatus buildCustomStateMachine(Attributor &A) {
3275     // If we have disabled state machine rewrites, don't make a custom one
3276     if (DisableOpenMPOptStateMachineRewrite)
3277       return indicatePessimisticFixpoint();
3278 
3279     assert(ReachedKnownParallelRegions.isValidState() &&
3280            "Custom state machine with invalid parallel region states?");
3281 
3282     const int InitIsSPMDArgNo = 1;
3283     const int InitUseStateMachineArgNo = 2;
3284 
    // Check if the current configuration is non-SPMD mode with the generic
    // state machine.
3286     // If we already have SPMD mode or a custom state machine we do not need to
3287     // go any further. If it is anything but a constant something is weird and
3288     // we give up.
3289     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3290         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3291     ConstantInt *IsSPMD =
3292         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3293 
3294     // If we are stuck with generic mode, try to create a custom device (=GPU)
3295     // state machine which is specialized for the parallel regions that are
3296     // reachable by the kernel.
3297     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3298         !IsSPMD->isZero())
3299       return ChangeStatus::UNCHANGED;
3300 
3301     // If not SPMD mode, indicate we use a custom state machine now.
3302     auto &Ctx = getAnchorValue().getContext();
3303     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3304     A.changeUseAfterManifest(
3305         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3306 
3307     // If we don't actually need a state machine we are done here. This can
3308     // happen if there simply are no parallel regions. In the resulting kernel
3309     // all worker threads will simply exit right away, leaving the main thread
3310     // to do the work alone.
3311     if (ReachedKnownParallelRegions.empty() &&
3312         ReachedUnknownParallelRegions.empty()) {
3313       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3314 
3315       auto Remark = [&](OptimizationRemark OR) {
3316         return OR << "Removing unused state machine from generic-mode kernel.";
3317       };
3318       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3319 
3320       return ChangeStatus::CHANGED;
3321     }
3322 
3323     // Keep track in the statistics of our new shiny custom state machine.
3324     if (ReachedUnknownParallelRegions.empty()) {
3325       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3326 
3327       auto Remark = [&](OptimizationRemark OR) {
3328         return OR << "Rewriting generic-mode kernel with a customized state "
3329                      "machine.";
3330       };
3331       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3332     } else {
3333       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3334 
3335       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3336         return OR << "Generic-mode kernel is executed with a customized state "
3337                      "machine that requires a fallback.";
3338       };
3339       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3340 
3341       // Tell the user why we ended up with a fallback.
3342       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3343         if (!UnknownParallelRegionCB)
3344           continue;
3345         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3346           return ORA << "Call may contain unknown parallel regions. Use "
3347                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3348                         "override.";
3349         };
3350         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3351                                                  "OMP133", Remark);
3352       }
3353     }
3354 
3355     // Create all the blocks:
3356     //
3357     //                       InitCB = __kmpc_target_init(...)
3358     //                       bool IsWorker = InitCB >= 0;
3359     //                       if (IsWorker) {
3360     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3361     //                         void *WorkFn;
3362     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3363     //                         if (!WorkFn) return;
3364     // SMIsActiveCheckBB:       if (Active) {
3365     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3366     //                              ParFn0(...);
3367     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3368     //                              ParFn1(...);
3369     //                            ...
3370     // SMIfCascadeCurrentBB:      else
3371     //                              ((WorkFnTy*)WorkFn)(...);
3372     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3373     //                          }
3374     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3375     //                          goto SMBeginBB;
3376     //                       }
3377     // UserCodeEntryBB:      // user code
3378     //                       __kmpc_target_deinit(...)
3379     //
3380     Function *Kernel = getAssociatedFunction();
3381     assert(Kernel && "Expected an associated function!");
3382 
3383     BasicBlock *InitBB = KernelInitCB->getParent();
3384     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3385         KernelInitCB->getNextNode(), "thread.user_code.check");
3386     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3387         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3388     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3389         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3390     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3391         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3392     BasicBlock *StateMachineIfCascadeCurrentBB =
3393         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3394                            Kernel, UserCodeEntryBB);
3395     BasicBlock *StateMachineEndParallelBB =
3396         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3397                            Kernel, UserCodeEntryBB);
3398     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3399         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3400     A.registerManifestAddedBasicBlock(*InitBB);
3401     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3402     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3403     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3404     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3405     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3406     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3407     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3408 
3409     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3410     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3411 
3412     InitBB->getTerminator()->eraseFromParent();
3413     Instruction *IsWorker =
3414         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3415                          ConstantInt::get(KernelInitCB->getType(), -1),
3416                          "thread.is_worker", InitBB);
3417     IsWorker->setDebugLoc(DLoc);
3418     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3419 
3420     // Create local storage for the work function pointer.
3421     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3422     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3423                                           &Kernel->getEntryBlock().front());
3424     WorkFnAI->setDebugLoc(DLoc);
3425 
3426     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3427     OMPInfoCache.OMPBuilder.updateToLocation(
3428         OpenMPIRBuilder::LocationDescription(
3429             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3430                                      StateMachineBeginBB->end()),
3431             DLoc));
3432 
3433     Value *Ident = KernelInitCB->getArgOperand(0);
3434     Value *GTid = KernelInitCB;
3435 
3436     Module &M = *Kernel->getParent();
3437     FunctionCallee BarrierFn =
3438         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3439             M, OMPRTL___kmpc_barrier_simple_spmd);
3440     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3441         ->setDebugLoc(DLoc);
3442 
3443     FunctionCallee KernelParallelFn =
3444         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3445             M, OMPRTL___kmpc_kernel_parallel);
3446     Instruction *IsActiveWorker = CallInst::Create(
3447         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3448     IsActiveWorker->setDebugLoc(DLoc);
3449     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3450                                        StateMachineBeginBB);
3451     WorkFn->setDebugLoc(DLoc);
3452 
3453     FunctionType *ParallelRegionFnTy = FunctionType::get(
3454         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3455         false);
3456     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3457         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3458         StateMachineBeginBB);
3459 
3460     Instruction *IsDone =
3461         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3462                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3463                          StateMachineBeginBB);
3464     IsDone->setDebugLoc(DLoc);
3465     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3466                        IsDone, StateMachineBeginBB)
3467         ->setDebugLoc(DLoc);
3468 
3469     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3470                        StateMachineDoneBarrierBB, IsActiveWorker,
3471                        StateMachineIsActiveCheckBB)
3472         ->setDebugLoc(DLoc);
3473 
3474     Value *ZeroArg =
3475         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
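    // Note: the known parallel region wrapper functions are assumed to take an
    // i16 (unused on this path) and the i32 thread id; ZeroArg above supplies
    // the i16 and GTid, the value returned by __kmpc_target_init, supplies the
    // thread id in the calls created below.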
3476 
3477     // Now that we have most of the CFG skeleton it is time for the if-cascade
3478     // that checks the function pointer we got from the runtime against the
3479     // parallel regions we expect, if there are any.
3480     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3481       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3482       BasicBlock *PRExecuteBB = BasicBlock::Create(
3483           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3484           StateMachineEndParallelBB);
3485       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3486           ->setDebugLoc(DLoc);
3487       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3488           ->setDebugLoc(DLoc);
3489 
3490       BasicBlock *PRNextBB =
3491           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3492                              Kernel, StateMachineEndParallelBB);
3493 
3494       // Check if we need to compare the pointer at all or if we can just
3495       // call the parallel region function.
3496       Value *IsPR;
3497       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3498         Instruction *CmpI = ICmpInst::Create(
3499             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3500             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3501         CmpI->setDebugLoc(DLoc);
3502         IsPR = CmpI;
3503       } else {
3504         IsPR = ConstantInt::getTrue(Ctx);
3505       }
3506 
3507       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3508                          StateMachineIfCascadeCurrentBB)
3509           ->setDebugLoc(DLoc);
3510       StateMachineIfCascadeCurrentBB = PRNextBB;
3511     }
3512 
    // At the end of the if-cascade we place the indirect function pointer call
    // in case we might need it, that is, if there can be parallel regions we
    // have not handled in the if-cascade above.
3516     if (!ReachedUnknownParallelRegions.empty()) {
3517       StateMachineIfCascadeCurrentBB->setName(
3518           "worker_state_machine.parallel_region.fallback.execute");
3519       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3520                        StateMachineIfCascadeCurrentBB)
3521           ->setDebugLoc(DLoc);
3522     }
3523     BranchInst::Create(StateMachineEndParallelBB,
3524                        StateMachineIfCascadeCurrentBB)
3525         ->setDebugLoc(DLoc);
3526 
3527     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3528                          M, OMPRTL___kmpc_kernel_end_parallel),
3529                      {}, "", StateMachineEndParallelBB)
3530         ->setDebugLoc(DLoc);
3531     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3532         ->setDebugLoc(DLoc);
3533 
3534     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3535         ->setDebugLoc(DLoc);
3536     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3537         ->setDebugLoc(DLoc);
3538 
3539     return ChangeStatus::CHANGED;
3540   }
3541 
3542   /// Fixpoint iteration update function. Will be called every time a dependence
3543   /// changed its state (and in the beginning).
3544   ChangeStatus updateImpl(Attributor &A) override {
3545     KernelInfoState StateBefore = getState();
3546 
3547     // Callback to check a read/write instruction.
3548     auto CheckRWInst = [&](Instruction &I) {
3549       // We handle calls later.
3550       if (isa<CallBase>(I))
3551         return true;
3552       // We only care about write effects.
3553       if (!I.mayWriteToMemory())
3554         return true;
3555       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3556         SmallVector<const Value *> Objects;
3557         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3558         if (llvm::all_of(Objects,
3559                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3560           return true;
3561         // Check for AAHeapToStack moved objects which must not be guarded.
3562         auto &HS = A.getAAFor<AAHeapToStack>(
3563             *this, IRPosition::function(*I.getFunction()),
3564             DepClassTy::REQUIRED);
3565         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3566               auto *CB = dyn_cast<CallBase>(Obj);
3567               if (!CB)
3568                 return false;
3569               return HS.isAssumedHeapToStack(*CB);
3570             })) {
3571           return true;
3572         }
3573       }
3574 
3575       // Insert instruction that needs guarding.
3576       SPMDCompatibilityTracker.insert(&I);
3577       return true;
3578     };
3579 
3580     bool UsedAssumedInformationInCheckRWInst = false;
3581     if (!SPMDCompatibilityTracker.isAtFixpoint())
3582       if (!A.checkForAllReadWriteInstructions(
3583               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3584         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3585 
3586     if (!IsKernelEntry) {
3587       updateReachingKernelEntries(A);
3588       updateParallelLevels(A);
3589 
3590       if (!ParallelLevels.isValidState())
3591         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3592     }
3593 
3594     // Callback to check a call instruction.
3595     bool AllSPMDStatesWereFixed = true;
3596     auto CheckCallInst = [&](Instruction &I) {
3597       auto &CB = cast<CallBase>(I);
3598       auto &CBAA = A.getAAFor<AAKernelInfo>(
3599           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3600       getState() ^= CBAA.getState();
3601       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3602       return true;
3603     };
3604 
3605     bool UsedAssumedInformationInCheckCallInst = false;
3606     if (!A.checkForAllCallLikeInstructions(
3607             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3608       return indicatePessimisticFixpoint();
3609 
3610     // If we haven't used any assumed information for the SPMD state we can fix
3611     // it.
3612     if (!UsedAssumedInformationInCheckRWInst &&
3613         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3614       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3615 
3616     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3617                                      : ChangeStatus::CHANGED;
3618   }
3619 
3620 private:
3621   /// Update info regarding reaching kernels.
3622   void updateReachingKernelEntries(Attributor &A) {
3623     auto PredCallSite = [&](AbstractCallSite ACS) {
3624       Function *Caller = ACS.getInstruction()->getFunction();
3625 
3626       assert(Caller && "Caller is nullptr");
3627 
3628       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3629           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3630       if (CAA.ReachingKernelEntries.isValidState()) {
3631         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3632         return true;
3633       }
3634 
      // We lost track of the callers of the associated function; any kernel
      // could reach it now.
3637       ReachingKernelEntries.indicatePessimisticFixpoint();
3638 
3639       return true;
3640     };
3641 
3642     bool AllCallSitesKnown;
3643     if (!A.checkForAllCallSites(PredCallSite, *this,
3644                                 true /* RequireAllCallSites */,
3645                                 AllCallSitesKnown))
3646       ReachingKernelEntries.indicatePessimisticFixpoint();
3647   }
3648 
3649   /// Update info regarding parallel levels.
3650   void updateParallelLevels(Attributor &A) {
3651     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3652     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3653         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3654 
3655     auto PredCallSite = [&](AbstractCallSite ACS) {
3656       Function *Caller = ACS.getInstruction()->getFunction();
3657 
3658       assert(Caller && "Caller is nullptr");
3659 
3660       auto &CAA =
3661           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3662       if (CAA.ParallelLevels.isValidState()) {
3663         // Any function that is called by `__kmpc_parallel_51` will not be
        // folded as the parallel level in the function is updated. Getting this
        // right would make the analysis depend on the runtime implementation,
        // and any future change to that implementation could invalidate the
        // analysis. As a consequence, we are simply conservative here.
3668         if (Caller == Parallel51RFI.Declaration) {
3669           ParallelLevels.indicatePessimisticFixpoint();
3670           return true;
3671         }
3672 
3673         ParallelLevels ^= CAA.ParallelLevels;
3674 
3675         return true;
3676       }
3677 
      // We lost track of the callers of the associated function; any kernel
      // could reach it now.
3680       ParallelLevels.indicatePessimisticFixpoint();
3681 
3682       return true;
3683     };
3684 
3685     bool AllCallSitesKnown = true;
3686     if (!A.checkForAllCallSites(PredCallSite, *this,
3687                                 true /* RequireAllCallSites */,
3688                                 AllCallSitesKnown))
3689       ParallelLevels.indicatePessimisticFixpoint();
3690   }
3691 };
3692 
3693 /// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regard to the KernelInfoState. For now this simply
3695 /// forwards the information from the callee.
3696 struct AAKernelInfoCallSite : AAKernelInfo {
3697   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3698       : AAKernelInfo(IRP, A) {}
3699 
3700   /// See AbstractAttribute::initialize(...).
3701   void initialize(Attributor &A) override {
3702     AAKernelInfo::initialize(A);
3703 
3704     CallBase &CB = cast<CallBase>(getAssociatedValue());
3705     Function *Callee = getAssociatedFunction();
3706 
3707     // Helper to lookup an assumption string.
3708     auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
3709       return Fn && hasAssumption(*Fn, AssumptionStr);
3710     };
3711 
3712     // Check for SPMD-mode assumptions.
3713     if (HasAssumption(Callee, "ompx_spmd_amenable"))
3714       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3715 
3716     // First weed out calls we do not care about, that is readonly/readnone
    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3718     // parallel region or anything else we are looking for.
3719     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3720       indicateOptimisticFixpoint();
3721       return;
3722     }
3723 
3724     // Next we check if we know the callee. If it is a known OpenMP function
3725     // we will handle them explicitly in the switch below. If it is not, we
3726     // will use an AAKernelInfo object on the callee to gather information and
3727     // merge that into the current state. The latter happens in the updateImpl.
3728     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3729     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3730     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
      // Unknown callees or declarations are not analyzable; we give up.
3732       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3733 
3734         // Unknown callees might contain parallel regions, except if they have
3735         // an appropriate assumption attached.
3736         if (!(HasAssumption(Callee, "omp_no_openmp") ||
3737               HasAssumption(Callee, "omp_no_parallelism")))
3738           ReachedUnknownParallelRegions.insert(&CB);
3739 
3740         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3741         // idea we can run something unknown in SPMD-mode.
3742         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3743           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3744           SPMDCompatibilityTracker.insert(&CB);
3745         }
3746 
        // We have updated the state for this unknown call properly; there won't
        // be any change, so we indicate a fixpoint.
3749         indicateOptimisticFixpoint();
3750       }
3751       // If the callee is known and can be used in IPO, we will update the state
3752       // based on the callee state in updateImpl.
3753       return;
3754     }
3755 
3756     const unsigned int WrapperFunctionArgNo = 6;
3757     RuntimeFunction RF = It->getSecond();
3758     switch (RF) {
3759     // All the functions we know are compatible with SPMD mode.
3760     case OMPRTL___kmpc_is_spmd_exec_mode:
3761     case OMPRTL___kmpc_for_static_fini:
3762     case OMPRTL___kmpc_global_thread_num:
3763     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3764     case OMPRTL___kmpc_get_hardware_num_blocks:
3765     case OMPRTL___kmpc_single:
3766     case OMPRTL___kmpc_end_single:
3767     case OMPRTL___kmpc_master:
3768     case OMPRTL___kmpc_end_master:
3769     case OMPRTL___kmpc_barrier:
3770       break;
3771     case OMPRTL___kmpc_for_static_init_4:
3772     case OMPRTL___kmpc_for_static_init_4u:
3773     case OMPRTL___kmpc_for_static_init_8:
3774     case OMPRTL___kmpc_for_static_init_8u: {
3775       // Check the schedule and allow static schedule in SPMD mode.
3776       unsigned ScheduleArgOpNo = 2;
3777       auto *ScheduleTypeCI =
3778           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3779       unsigned ScheduleTypeVal =
3780           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3781       switch (OMPScheduleType(ScheduleTypeVal)) {
3782       case OMPScheduleType::Static:
3783       case OMPScheduleType::StaticChunked:
3784       case OMPScheduleType::Distribute:
3785       case OMPScheduleType::DistributeChunked:
3786         break;
3787       default:
3788         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3789         SPMDCompatibilityTracker.insert(&CB);
3790         break;
3791       };
3792     } break;
3793     case OMPRTL___kmpc_target_init:
3794       KernelInitCB = &CB;
3795       break;
3796     case OMPRTL___kmpc_target_deinit:
3797       KernelDeinitCB = &CB;
3798       break;
3799     case OMPRTL___kmpc_parallel_51:
3800       if (auto *ParallelRegion = dyn_cast<Function>(
3801               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3802         ReachedKnownParallelRegions.insert(ParallelRegion);
3803         break;
3804       }
3805       // The condition above should usually get the parallel region function
      // pointer and record it. On the off chance it doesn't, we assume the
      // worst.
3808       ReachedUnknownParallelRegions.insert(&CB);
3809       break;
3810     case OMPRTL___kmpc_omp_task:
3811       // We do not look into tasks right now, just give up.
3812       SPMDCompatibilityTracker.insert(&CB);
3813       ReachedUnknownParallelRegions.insert(&CB);
3814       indicatePessimisticFixpoint();
3815       return;
3816     case OMPRTL___kmpc_alloc_shared:
3817     case OMPRTL___kmpc_free_shared:
3818       // Return without setting a fixpoint, to be resolved in updateImpl.
3819       return;
3820     default:
3821       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3822       // generally.
3823       SPMDCompatibilityTracker.insert(&CB);
3824       indicatePessimisticFixpoint();
3825       return;
3826     }
3827     // All other OpenMP runtime calls will not reach parallel regions so they
3828     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3829     // have now modeled all effects and there is no need for any update.
3830     indicateOptimisticFixpoint();
3831   }
3832 
3833   ChangeStatus updateImpl(Attributor &A) override {
3834     // TODO: Once we have call site specific value information we can provide
3835     //       call site specific liveness information and then it makes
3836     //       sense to specialize attributes for call sites arguments instead of
3837     //       redirecting requests to the callee argument.
3838     Function *F = getAssociatedFunction();
3839 
3840     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3841     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3842 
3843     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3844     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3845       const IRPosition &FnPos = IRPosition::function(*F);
3846       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3847       if (getState() == FnAA.getState())
3848         return ChangeStatus::UNCHANGED;
3849       getState() = FnAA.getState();
3850       return ChangeStatus::CHANGED;
3851     }
3852 
3853     // F is a runtime function that allocates or frees memory, check
3854     // AAHeapToStack and AAHeapToShared.
3855     KernelInfoState StateBefore = getState();
3856     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3857             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3858            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3859 
3860     CallBase &CB = cast<CallBase>(getAssociatedValue());
3861 
3862     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3863         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3864     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3865         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3866 
3867     RuntimeFunction RF = It->getSecond();
3868 
3869     switch (RF) {
3870     // If neither HeapToStack nor HeapToShared assume the call is removed,
3871     // assume SPMD incompatibility.
3872     case OMPRTL___kmpc_alloc_shared:
3873       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3874           !HeapToSharedAA.isAssumedHeapToShared(CB))
3875         SPMDCompatibilityTracker.insert(&CB);
3876       break;
3877     case OMPRTL___kmpc_free_shared:
3878       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3879           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3880         SPMDCompatibilityTracker.insert(&CB);
3881       break;
3882     default:
3883       SPMDCompatibilityTracker.insert(&CB);
3884     }
3885 
3886     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3887                                      : ChangeStatus::CHANGED;
3888   }
3889 };
3890 
3891 struct AAFoldRuntimeCall
3892     : public StateWrapper<BooleanState, AbstractAttribute> {
3893   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3894 
3895   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3896 
3897   /// Statistics are tracked as part of manifest for now.
3898   void trackStatistics() const override {}
3899 
  /// Create an abstract attribute view for the position \p IRP.
3901   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3902                                               Attributor &A);
3903 
3904   /// See AbstractAttribute::getName()
3905   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3906 
3907   /// See AbstractAttribute::getIdAddr()
3908   const char *getIdAddr() const override { return &ID; }
3909 
3910   /// This function should return true if the type of the \p AA is
3911   /// AAFoldRuntimeCall
3912   static bool classof(const AbstractAttribute *AA) {
3913     return (AA->getIdAddr() == &ID);
3914   }
3915 
3916   static const char ID;
3917 };
3918 
3919 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3920   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3921       : AAFoldRuntimeCall(IRP, A) {}
3922 
3923   /// See AbstractAttribute::getAsStr()
3924   const std::string getAsStr() const override {
3925     if (!isValidState())
3926       return "<invalid>";
3927 
3928     std::string Str("simplified value: ");
3929 
3930     if (!SimplifiedValue.hasValue())
3931       return Str + std::string("none");
3932 
3933     if (!SimplifiedValue.getValue())
3934       return Str + std::string("nullptr");
3935 
3936     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
3937       return Str + std::to_string(CI->getSExtValue());
3938 
3939     return Str + std::string("unknown");
3940   }
3941 
3942   void initialize(Attributor &A) override {
3943     if (DisableOpenMPOptFolding)
3944       indicatePessimisticFixpoint();
3945 
3946     Function *Callee = getAssociatedFunction();
3947 
3948     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3949     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3950     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
3951            "Expected a known OpenMP runtime function");
3952 
3953     RFKind = It->getSecond();
3954 
3955     CallBase &CB = cast<CallBase>(getAssociatedValue());
3956     A.registerSimplificationCallback(
3957         IRPosition::callsite_returned(CB),
3958         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3959             bool &UsedAssumedInformation) -> Optional<Value *> {
3960           assert((isValidState() || (SimplifiedValue.hasValue() &&
3961                                      SimplifiedValue.getValue() == nullptr)) &&
3962                  "Unexpected invalid state!");
3963 
3964           if (!isAtFixpoint()) {
3965             UsedAssumedInformation = true;
3966             if (AA)
3967               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3968           }
3969           return SimplifiedValue;
3970         });
3971   }
3972 
3973   ChangeStatus updateImpl(Attributor &A) override {
3974     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3975     switch (RFKind) {
3976     case OMPRTL___kmpc_is_spmd_exec_mode:
3977       Changed |= foldIsSPMDExecMode(A);
3978       break;
3979     case OMPRTL___kmpc_is_generic_main_thread_id:
3980       Changed |= foldIsGenericMainThread(A);
3981       break;
3982     case OMPRTL___kmpc_parallel_level:
3983       Changed |= foldParallelLevel(A);
3984       break;
3985     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3986       Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
3987       break;
3988     case OMPRTL___kmpc_get_hardware_num_blocks:
3989       Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
3990       break;
3991     default:
3992       llvm_unreachable("Unhandled OpenMP runtime function!");
3993     }
3994 
3995     return Changed;
3996   }
3997 
3998   ChangeStatus manifest(Attributor &A) override {
3999     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4000 
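         // If a constant replacement is known or assumed, rewrite all uses of
         // the runtime call with it and delete the call.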
4001     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4002       Instruction &CB = *getCtxI();
4003       A.changeValueAfterManifest(CB, **SimplifiedValue);
4004       A.deleteAfterManifest(CB);
4005 
4006       LLVM_DEBUG(dbgs() << TAG << "Folding runtime call: " << CB << " with "
4007                         << **SimplifiedValue << "\n");
4008 
4009       Changed = ChangeStatus::CHANGED;
4010     }
4011 
4012     return Changed;
4013   }
4014 
4015   ChangeStatus indicatePessimisticFixpoint() override {
4016     SimplifiedValue = nullptr;
4017     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4018   }
4019 
4020 private:
4021   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4022   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4023     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4024 
4025     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4026     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4027     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4028         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4029 
4030     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4031       return indicatePessimisticFixpoint();
4032 
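         // Count how many of the reaching kernels are known or assumed to run
         // in SPMD mode and how many in generic (non-SPMD) mode.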
4033     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4034       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4035                                           DepClassTy::REQUIRED);
4036 
4037       if (!AA.isValidState()) {
4038         SimplifiedValue = nullptr;
4039         return indicatePessimisticFixpoint();
4040       }
4041 
4042       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4043         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4044           ++KnownSPMDCount;
4045         else
4046           ++AssumedSPMDCount;
4047       } else {
4048         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4049           ++KnownNonSPMDCount;
4050         else
4051           ++AssumedNonSPMDCount;
4052       }
4053     }
4054 
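         // If both SPMD and non-SPMD kernels can reach this call site, the
         // result is not uniform and the call cannot be folded.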
4055     if ((AssumedSPMDCount + KnownSPMDCount) &&
4056         (AssumedNonSPMDCount + KnownNonSPMDCount))
4057       return indicatePessimisticFixpoint();
4058 
4059     auto &Ctx = getAnchorValue().getContext();
4060     if (KnownSPMDCount || AssumedSPMDCount) {
4061       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4062              "Expected only SPMD kernels!");
4063       // All reaching kernels are in SPMD mode. Update all function calls to
4064       // __kmpc_is_spmd_exec_mode to 1.
4065       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4066     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4067       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4068              "Expected only non-SPMD kernels!");
4069       // All reaching kernels are in non-SPMD mode. Update all function
4070       // calls to __kmpc_is_spmd_exec_mode to 0.
4071       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4072     } else {
4073       // The set of reaching kernels is empty, so we cannot tell whether the
4074       // associated call site can be folded. At this point, SimplifiedValue
4075       // must still be none.
4076       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4077     }
4078 
4079     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4080                                                     : ChangeStatus::CHANGED;
4081   }
4082 
4083   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4084   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4085     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4086 
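         // The call folds to true only if the execution domain of the caller
         // shows it is executed by the initial thread exclusively.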
4087     CallBase &CB = cast<CallBase>(getAssociatedValue());
4088     Function *F = CB.getFunction();
4089     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4090         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4091 
4092     if (!ExecutionDomainAA.isValidState())
4093       return indicatePessimisticFixpoint();
4094 
4095     auto &Ctx = getAnchorValue().getContext();
4096     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4097       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4098     else
4099       return indicatePessimisticFixpoint();
4100 
4101     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4102                                                     : ChangeStatus::CHANGED;
4103   }
4104 
4105   /// Fold __kmpc_parallel_level into a constant if possible.
4106   ChangeStatus foldParallelLevel(Attributor &A) {
4107     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4108 
4109     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4110         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4111 
4112     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4113       return indicatePessimisticFixpoint();
4114 
4115     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4116       return indicatePessimisticFixpoint();
4117 
4118     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4119       assert(!SimplifiedValue.hasValue() &&
4120              "SimplifiedValue should still be none at this point");
4121       return ChangeStatus::UNCHANGED;
4122     }
4123 
4124     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4125     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4126     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4127       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4128                                           DepClassTy::REQUIRED);
4129       if (!AA.SPMDCompatibilityTracker.isValidState())
4130         return indicatePessimisticFixpoint();
4131 
4132       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4133         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4134           ++KnownSPMDCount;
4135         else
4136           ++AssumedSPMDCount;
4137       } else {
4138         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4139           ++KnownNonSPMDCount;
4140         else
4141           ++AssumedNonSPMDCount;
4142       }
4143     }
4144 
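         // Mixed SPMD and generic-mode reachability means the parallel level
         // at this call site is not a single constant; give up.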
4145     if ((AssumedSPMDCount + KnownSPMDCount) &&
4146         (AssumedNonSPMDCount + KnownNonSPMDCount))
4147       return indicatePessimisticFixpoint();
4148 
4149     auto &Ctx = getAnchorValue().getContext();
4150     // If the caller can only be reached by SPMD kernel entries, the parallel
4151     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4152     // kernel entries, it is 0.
4153     if (AssumedSPMDCount || KnownSPMDCount) {
4154       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4155              "Expected only SPMD kernels!");
4156       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4157     } else {
4158       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4159              "Expected only non-SPMD kernels!");
4160       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4161     }
4162     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4163                                                     : ChangeStatus::CHANGED;
4164   }
4165 
4166   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4167     // Specialize only if all reaching kernels agree on the attribute value.
4168     int32_t CurrentAttrValue = -1;
4169     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4170 
4171     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4172         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4173 
4174     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4175       return indicatePessimisticFixpoint();
4176 
4177     // Iterate over the kernels that reach this function
4178     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4179       int32_t NextAttrVal = -1;
4180       if (K->hasFnAttribute(Attr))
4181         NextAttrVal =
4182             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4183 
4184       if (NextAttrVal == -1 ||
4185           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4186         return indicatePessimisticFixpoint();
4187       CurrentAttrValue = NextAttrVal;
4188     }
4189 
4190     if (CurrentAttrValue != -1) {
4191       auto &Ctx = getAnchorValue().getContext();
4192       SimplifiedValue =
4193           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4194     }
4195     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4196                                                     : ChangeStatus::CHANGED;
4197   }
4198 
4199   /// An optional value the associated value is assumed to fold to. That is, we
4200   /// assume the associated value (which is a call) can be replaced by this
4201   /// simplified value.
4202   Optional<Value *> SimplifiedValue;
4203 
4204   /// The runtime function kind of the callee of the associated call site.
4205   RuntimeFunction RFKind;
4206 };
4207 
4208 } // namespace
4209 
4210 /// Register folding call site AAs for the runtime function \p RF.
4211 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4212   auto &RFI = OMPInfoCache.RFIs[RF];
4213   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4214     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4215     if (!CI)
4216       return false;
4217     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4218         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4219         DepClassTy::NONE, /* ForceUpdate */ false,
4220         /* UpdateAfterInit */ false);
4221     return false;
4222   });
4223 }
4224 
4225 void OpenMPOpt::registerAAs(bool IsModulePass) {
4226   if (SCC.empty())
4227     return;
4228 
4229   if (IsModulePass) {
4230     // Ensure we create the AAKernelInfo AAs first and without triggering an
4231     // update. This will make sure we register all value simplification
4232     // callbacks before any other AA has the chance to create an AAValueSimplify
4233     // or similar.
4234     for (Function *Kernel : OMPInfoCache.Kernels)
4235       A.getOrCreateAAFor<AAKernelInfo>(
4236           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4237           DepClassTy::NONE, /* ForceUpdate */ false,
4238           /* UpdateAfterInit */ false);
4239 
4240 
4241     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4242     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4243     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4244     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4245     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4246   }
4247 
4248   // Create CallSite AA for all Getters.
4249   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4250     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4251 
4252     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4253 
4254     auto CreateAA = [&](Use &U, Function &Caller) {
4255       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4256       if (!CI)
4257         return false;
4258 
4259       auto &CB = cast<CallBase>(*CI);
4260 
4261       IRPosition CBPos = IRPosition::callsite_function(CB);
4262       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4263       return false;
4264     };
4265 
4266     GetterRFI.foreachUse(SCC, CreateAA);
4267   }
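       // Create an AAHeapToShared for every function that contains a call to
       // __kmpc_alloc_shared, unless deglobalization is disabled.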
4268   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4269   auto CreateAA = [&](Use &U, Function &F) {
4270     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4271     return false;
4272   };
4273   if (!DisableOpenMPOptDeglobalization)
4274     GlobalizationRFI.foreachUse(SCC, CreateAA);
4275 
4276   // For device modules, create an ExecutionDomain AA for every function and,
4277   // unless deglobalization is disabled, a HeapToStack AA as well.
4278   if (!isOpenMPDevice(M))
4279     return;
4280 
4281   for (auto *F : SCC) {
4282     if (F->isDeclaration())
4283       continue;
4284 
4285     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4286     if (!DisableOpenMPOptDeglobalization)
4287       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4288 
4289     for (auto &I : instructions(*F)) {
4290       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4291         bool UsedAssumedInformation = false;
4292         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4293                                UsedAssumedInformation);
4294       }
4295     }
4296   }
4297 }
4298 
4299 const char AAICVTracker::ID = 0;
4300 const char AAKernelInfo::ID = 0;
4301 const char AAExecutionDomain::ID = 0;
4302 const char AAHeapToShared::ID = 0;
4303 const char AAFoldRuntimeCall::ID = 0;
4304 
4305 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4306                                               Attributor &A) {
4307   AAICVTracker *AA = nullptr;
4308   switch (IRP.getPositionKind()) {
4309   case IRPosition::IRP_INVALID:
4310   case IRPosition::IRP_FLOAT:
4311   case IRPosition::IRP_ARGUMENT:
4312   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4313     llvm_unreachable("ICVTracker can only be created for function position!");
4314   case IRPosition::IRP_RETURNED:
4315     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4316     break;
4317   case IRPosition::IRP_CALL_SITE_RETURNED:
4318     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4319     break;
4320   case IRPosition::IRP_CALL_SITE:
4321     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4322     break;
4323   case IRPosition::IRP_FUNCTION:
4324     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4325     break;
4326   }
4327 
4328   return *AA;
4329 }
4330 
4331 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4332                                                         Attributor &A) {
4333   AAExecutionDomainFunction *AA = nullptr;
4334   switch (IRP.getPositionKind()) {
4335   case IRPosition::IRP_INVALID:
4336   case IRPosition::IRP_FLOAT:
4337   case IRPosition::IRP_ARGUMENT:
4338   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4339   case IRPosition::IRP_RETURNED:
4340   case IRPosition::IRP_CALL_SITE_RETURNED:
4341   case IRPosition::IRP_CALL_SITE:
4342     llvm_unreachable(
4343         "AAExecutionDomain can only be created for function position!");
4344   case IRPosition::IRP_FUNCTION:
4345     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4346     break;
4347   }
4348 
4349   return *AA;
4350 }
4351 
4352 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4353                                                   Attributor &A) {
4354   AAHeapToSharedFunction *AA = nullptr;
4355   switch (IRP.getPositionKind()) {
4356   case IRPosition::IRP_INVALID:
4357   case IRPosition::IRP_FLOAT:
4358   case IRPosition::IRP_ARGUMENT:
4359   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4360   case IRPosition::IRP_RETURNED:
4361   case IRPosition::IRP_CALL_SITE_RETURNED:
4362   case IRPosition::IRP_CALL_SITE:
4363     llvm_unreachable(
4364         "AAHeapToShared can only be created for function position!");
4365   case IRPosition::IRP_FUNCTION:
4366     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4367     break;
4368   }
4369 
4370   return *AA;
4371 }
4372 
4373 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4374                                               Attributor &A) {
4375   AAKernelInfo *AA = nullptr;
4376   switch (IRP.getPositionKind()) {
4377   case IRPosition::IRP_INVALID:
4378   case IRPosition::IRP_FLOAT:
4379   case IRPosition::IRP_ARGUMENT:
4380   case IRPosition::IRP_RETURNED:
4381   case IRPosition::IRP_CALL_SITE_RETURNED:
4382   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4383     llvm_unreachable("KernelInfo can only be created for function position!");
4384   case IRPosition::IRP_CALL_SITE:
4385     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4386     break;
4387   case IRPosition::IRP_FUNCTION:
4388     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4389     break;
4390   }
4391 
4392   return *AA;
4393 }
4394 
4395 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4396                                                         Attributor &A) {
4397   AAFoldRuntimeCall *AA = nullptr;
4398   switch (IRP.getPositionKind()) {
4399   case IRPosition::IRP_INVALID:
4400   case IRPosition::IRP_FLOAT:
4401   case IRPosition::IRP_ARGUMENT:
4402   case IRPosition::IRP_RETURNED:
4403   case IRPosition::IRP_FUNCTION:
4404   case IRPosition::IRP_CALL_SITE:
4405   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4406     llvm_unreachable("AAFoldRuntimeCall expects a call site position!");
4407   case IRPosition::IRP_CALL_SITE_RETURNED:
4408     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4409     break;
4410   }
4411 
4412   return *AA;
4413 }
4414 
4415 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4416   if (!containsOpenMP(M))
4417     return PreservedAnalyses::all();
4418   if (DisableOpenMPOptimizations)
4419     return PreservedAnalyses::all();
4420 
4421   FunctionAnalysisManager &FAM =
4422       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4423   KernelSet Kernels = getDeviceKernels(M);
4424 
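       // A function is considered called if it is a kernel or has any user
       // other than a block address.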
4425   auto IsCalled = [&](Function &F) {
4426     if (Kernels.contains(&F))
4427       return true;
4428     for (const User *U : F.users())
4429       if (!isa<BlockAddress>(U))
4430         return true;
4431     return false;
4432   };
4433 
4434   auto EmitRemark = [&](Function &F) {
4435     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4436     ORE.emit([&]() {
4437       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4438       return ORA << "Could not internalize function. "
4439                  << "Some optimizations may not be possible. [OMP140]";
4440     });
4441   };
4442 
4443   // Create internal copies of each function if this is a device module. This
4444   // allows interprocedural passes to see every call edge.
4445   DenseMap<Function *, Function *> InternalizedMap;
4446   if (isOpenMPDevice(M)) {
4447     SmallPtrSet<Function *, 16> InternalizeFns;
4448     for (Function &F : M)
4449       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4450           !DisableInternalization) {
4451         if (Attributor::isInternalizable(F)) {
4452           InternalizeFns.insert(&F);
4453         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4454           EmitRemark(F);
4455         }
4456       }
4457 
4458     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4459   }
4460 
4461   // Look at every function in the Module unless it was internalized.
4462   SmallVector<Function *, 16> SCC;
4463   for (Function &F : M)
4464     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4465       SCC.push_back(&F);
4466 
4467   if (SCC.empty())
4468     return PreservedAnalyses::all();
4469 
4470   AnalysisGetter AG(FAM);
4471 
4472   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4473     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4474   };
4475 
4476   BumpPtrAllocator Allocator;
4477   CallGraphUpdater CGUpdater;
4478 
4479   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4480   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4481 
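       // Allow more fixpoint iterations when optimizing device code.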
4482   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4483   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4484                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4485 
4486   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4487   bool Changed = OMPOpt.run(true);
4488 
4489   // Optionally inline device functions for potentially better performance.
4490   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4491     for (Function &F : M)
4492       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4493           !F.hasFnAttribute(Attribute::NoInline))
4494         F.addFnAttr(Attribute::AlwaysInline);
4495 
4496   if (PrintModuleAfterOptimizations)
4497     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4498 
4499   if (Changed)
4500     return PreservedAnalyses::none();
4501 
4502   return PreservedAnalyses::all();
4503 }
4504 
4505 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4506                                           CGSCCAnalysisManager &AM,
4507                                           LazyCallGraph &CG,
4508                                           CGSCCUpdateResult &UR) {
4509   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4510     return PreservedAnalyses::all();
4511   if (DisableOpenMPOptimizations)
4512     return PreservedAnalyses::all();
4513 
4514   SmallVector<Function *, 16> SCC;
4515   // If there are kernels in the module, we have to run on all SCCs.
4516   for (LazyCallGraph::Node &N : C) {
4517     Function *Fn = &N.getFunction();
4518     SCC.push_back(Fn);
4519   }
4520 
4521   if (SCC.empty())
4522     return PreservedAnalyses::all();
4523 
4524   Module &M = *C.begin()->getFunction().getParent();
4525 
4526   KernelSet Kernels = getDeviceKernels(M);
4527 
4528   FunctionAnalysisManager &FAM =
4529       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4530 
4531   AnalysisGetter AG(FAM);
4532 
4533   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4534     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4535   };
4536 
4537   BumpPtrAllocator Allocator;
4538   CallGraphUpdater CGUpdater;
4539   CGUpdater.initialize(CG, C, AM, UR);
4540 
4541   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4542   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4543                                 /*CGSCC*/ Functions, Kernels);
4544 
4545   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4546   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4547                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4548 
4549   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4550   bool Changed = OMPOpt.run(false);
4551 
4552   if (PrintModuleAfterOptimizations)
4553     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4554 
4555   if (Changed)
4556     return PreservedAnalyses::none();
4557 
4558   return PreservedAnalyses::all();
4559 }
4560 
4561 namespace {
4562 
4563 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4564   CallGraphUpdater CGUpdater;
4565   static char ID;
4566 
4567   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4568     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4569   }
4570 
4571   void getAnalysisUsage(AnalysisUsage &AU) const override {
4572     CallGraphSCCPass::getAnalysisUsage(AU);
4573   }
4574 
4575   bool runOnSCC(CallGraphSCC &CGSCC) override {
4576     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4577       return false;
4578     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4579       return false;
4580 
4581     SmallVector<Function *, 16> SCC;
4582     // If there are kernels in the module, we have to run on all SCCs.
4583     for (CallGraphNode *CGN : CGSCC) {
4584       Function *Fn = CGN->getFunction();
4585       if (!Fn || Fn->isDeclaration())
4586         continue;
4587       SCC.push_back(Fn);
4588     }
4589 
4590     if (SCC.empty())
4591       return false;
4592 
4593     Module &M = CGSCC.getCallGraph().getModule();
4594     KernelSet Kernels = getDeviceKernels(M);
4595 
4596     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4597     CGUpdater.initialize(CG, CGSCC);
4598 
4599     // Maintain a map of emitters so we do not rebuild the ORE for a function.
4600     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4601     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4602       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4603       if (!ORE)
4604         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4605       return *ORE;
4606     };
4607 
4608     AnalysisGetter AG;
4609     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4610     BumpPtrAllocator Allocator;
4611     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4612                                   Allocator,
4613                                   /*CGSCC*/ Functions, Kernels);
4614 
4615     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4616     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4617                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4618 
4619     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4620     bool Result = OMPOpt.run(false);
4621 
4622     if (PrintModuleAfterOptimizations)
4623       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4624 
4625     return Result;
4626   }
4627 
4628   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4629 };
4630 
4631 } // end anonymous namespace
4632 
4633 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4634   // TODO: Create a more cross-platform way of determining device kernels.
4635   NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
4636   KernelSet Kernels;
4637 
4638   if (!MD)
4639     return Kernels;
4640 
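       // Each "nvvm.annotations" operand of the form (function, "kernel", ...)
       // designates a device kernel entry point.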
4641   for (auto *Op : MD->operands()) {
4642     if (Op->getNumOperands() < 2)
4643       continue;
4644     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4645     if (!KindID || KindID->getString() != "kernel")
4646       continue;
4647 
4648     Function *KernelFn =
4649         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4650     if (!KernelFn)
4651       continue;
4652 
4653     ++NumOpenMPTargetRegionKernels;
4654 
4655     Kernels.insert(KernelFn);
4656   }
4657 
4658   return Kernels;
4659 }
4660 
4661 bool llvm::omp::containsOpenMP(Module &M) {
4662   Metadata *MD = M.getModuleFlag("openmp");
4663   if (!MD)
4664     return false;
4665 
4666   return true;
4667 }
4668 
4669 bool llvm::omp::isOpenMPDevice(Module &M) {
4670   Metadata *MD = M.getModuleFlag("openmp-device");
4671   if (!MD)
4672     return false;
4673 
4674   return true;
4675 }
4676 
4677 char OpenMPOptCGSCCLegacyPass::ID = 0;
4678 
4679 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4680                       "OpenMP specific optimizations", false, false)
4681 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4682 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4683                     "OpenMP specific optimizations", false, false)
4684 
4685 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4686   return new OpenMPOptCGSCCLegacyPass();
4687 }
4688