1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/InitializePasses.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Transforms/IPO.h"
39 #include "llvm/Transforms/IPO/Attributor.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
42 #include "llvm/Transforms/Utils/CodeExtractor.h"
43 
44 using namespace llvm;
45 using namespace omp;
46 
47 #define DEBUG_TYPE "openmp-opt"
48 
49 static cl::opt<bool> DisableOpenMPOptimizations(
50     "openmp-opt-disable", cl::ZeroOrMore,
51     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
52     cl::init(false));
53 
54 static cl::opt<bool> EnableParallelRegionMerging(
55     "openmp-opt-enable-merging", cl::ZeroOrMore,
56     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
57     cl::init(false));
58 
59 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
60                                     cl::Hidden);
61 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
62                                         cl::init(false), cl::Hidden);
63 
64 static cl::opt<bool> HideMemoryTransferLatency(
65     "openmp-hide-memory-transfer-latency",
66     cl::desc("[WIP] Tries to hide the latency of host to device memory"
67              " transfers"),
68     cl::Hidden, cl::init(false));
69 
70 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
71           "Number of OpenMP runtime calls deduplicated");
72 STATISTIC(NumOpenMPParallelRegionsDeleted,
73           "Number of OpenMP parallel regions deleted");
74 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
75           "Number of OpenMP runtime functions identified");
76 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
77           "Number of OpenMP runtime function uses identified");
78 STATISTIC(NumOpenMPTargetRegionKernels,
79           "Number of OpenMP target region entry points (=kernels) identified");
80 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
81           "Number of OpenMP target region entry points (=kernels) executed in "
82           "SPMD-mode instead of generic-mode");
83 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
84           "Number of OpenMP target region entry points (=kernels) executed in "
85           "generic-mode without a state machine");
86 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
87           "Number of OpenMP target region entry points (=kernels) executed in "
88           "generic-mode with customized state machines with fallback");
89 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
90           "Number of OpenMP target region entry points (=kernels) executed in "
91           "generic-mode with customized state machines without fallback");
92 STATISTIC(
93     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
94     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
95 STATISTIC(NumOpenMPParallelRegionsMerged,
96           "Number of OpenMP parallel regions merged");
97 STATISTIC(NumBytesMovedToSharedMemory,
98           "Amount of memory pushed to shared memory");
99 
100 #if !defined(NDEBUG)
101 static constexpr auto TAG = "[" DEBUG_TYPE "]";
102 #endif
103 
104 namespace {
105 
106 enum class AddressSpace : unsigned {
107   Generic = 0,
108   Global = 1,
109   Shared = 3,
110   Constant = 4,
111   Local = 5,
112 };
113 
114 struct AAHeapToShared;
115 
116 struct AAICVTracker;
117 
118 /// OpenMP-specific information. For now, this stores the RFIs and ICVs that
119 /// are also needed for Attributor runs.
120 struct OMPInformationCache : public InformationCache {
121   OMPInformationCache(Module &M, AnalysisGetter &AG,
122                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
123                       SmallPtrSetImpl<Kernel> &Kernels)
124       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
125         Kernels(Kernels) {
126 
127     OMPBuilder.initialize();
128     initializeRuntimeFunctions();
129     initializeInternalControlVars();
130   }
131 
132   /// Generic information that describes an internal control variable.
133   struct InternalControlVarInfo {
134     /// The kind, as described by InternalControlVar enum.
135     InternalControlVar Kind;
136 
137     /// The name of the ICV.
138     StringRef Name;
139 
140     /// Environment variable associated with this ICV.
141     StringRef EnvVarName;
142 
143     /// Initial value kind.
144     ICVInitValue InitKind;
145 
146     /// Initial value.
147     ConstantInt *InitValue;
148 
149     /// Setter RTL function associated with this ICV.
150     RuntimeFunction Setter;
151 
152     /// Getter RTL function associated with this ICV.
153     RuntimeFunction Getter;
154 
155     /// RTL function corresponding to the override clause of this ICV.
156     RuntimeFunction Clause;
157   };
158 
159   /// Generic information that describes a runtime function.
160   struct RuntimeFunctionInfo {
161 
162     /// The kind, as described by the RuntimeFunction enum.
163     RuntimeFunction Kind;
164 
165     /// The name of the function.
166     StringRef Name;
167 
168     /// Flag to indicate a variadic function.
169     bool IsVarArg;
170 
171     /// The return type of the function.
172     Type *ReturnType;
173 
174     /// The argument types of the function.
175     SmallVector<Type *, 8> ArgumentTypes;
176 
177     /// The declaration if available.
178     Function *Declaration = nullptr;
179 
180     /// Uses of this runtime function per function containing the use.
181     using UseVector = SmallVector<Use *, 16>;
182 
183     /// Clear UsesMap for runtime function.
184     void clearUsesMap() { UsesMap.clear(); }
185 
186     /// Boolean conversion that is true if the runtime function was found.
187     operator bool() const { return Declaration; }
188 
189     /// Return the vector of uses in function \p F.
190     UseVector &getOrCreateUseVector(Function *F) {
191       std::shared_ptr<UseVector> &UV = UsesMap[F];
192       if (!UV)
193         UV = std::make_shared<UseVector>();
194       return *UV;
195     }
196 
197     /// Return the vector of uses in function \p F or `nullptr` if there are
198     /// none.
199     const UseVector *getUseVector(Function &F) const {
200       auto I = UsesMap.find(&F);
201       if (I != UsesMap.end())
202         return I->second.get();
203       return nullptr;
204     }
205 
206     /// Return how many functions contain uses of this runtime function.
207     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
208 
209     /// Return the number of arguments (or the minimal number for variadic
210     /// functions).
211     size_t getNumArgs() const { return ArgumentTypes.size(); }
212 
213     /// Run the callback \p CB on each use and forget the use if the result is
214     /// true. The callback will be fed the function in which the use was
215     /// encountered as second argument.
216     void foreachUse(SmallVectorImpl<Function *> &SCC,
217                     function_ref<bool(Use &, Function &)> CB) {
218       for (Function *F : SCC)
219         foreachUse(CB, F);
220     }
221 
222     /// Run the callback \p CB on each use within the function \p F and forget
223     /// the use if the result is true.
224     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
225       SmallVector<unsigned, 8> ToBeDeleted;
226       ToBeDeleted.clear();
227 
228       unsigned Idx = 0;
229       UseVector &UV = getOrCreateUseVector(F);
230 
231       for (Use *U : UV) {
232         if (CB(*U, *F))
233           ToBeDeleted.push_back(Idx);
234         ++Idx;
235       }
236 
237       // Remove the to-be-deleted indices in reverse order as prior
238       // modifications will not modify the smaller indices.
239       while (!ToBeDeleted.empty()) {
240         unsigned Idx = ToBeDeleted.pop_back_val();
241         UV[Idx] = UV.back();
242         UV.pop_back();
243       }
244     }
245 
246   private:
247     /// Map from functions to all uses of this runtime function contained in
248     /// them.
249     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
250 
251   public:
252     /// Iterators for the uses of this runtime function.
253     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
254     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
255   };
256 
257   /// An OpenMP-IR-Builder instance
258   OpenMPIRBuilder OMPBuilder;
259 
260   /// Map from runtime function kind to the runtime function description.
261   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
262                   RuntimeFunction::OMPRTL___last>
263       RFIs;
264 
265   /// Map from function declarations/definitions to their runtime enum type.
266   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
267 
268   /// Map from ICV kind to the ICV description.
269   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
270                   InternalControlVar::ICV___last>
271       ICVs;
272 
273   /// Helper to initialize all internal control variable information for those
274   /// defined in OMPKinds.def.
275   void initializeInternalControlVars() {
276 #define ICV_RT_SET(_Name, RTL)                                                 \
277   {                                                                            \
278     auto &ICV = ICVs[_Name];                                                   \
279     ICV.Setter = RTL;                                                          \
280   }
281 #define ICV_RT_GET(Name, RTL)                                                  \
282   {                                                                            \
283     auto &ICV = ICVs[Name];                                                    \
284     ICV.Getter = RTL;                                                          \
285   }
286 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
287   {                                                                            \
288     auto &ICV = ICVs[Enum];                                                    \
289     ICV.Name = _Name;                                                          \
290     ICV.Kind = Enum;                                                           \
291     ICV.InitKind = Init;                                                       \
292     ICV.EnvVarName = _EnvVarName;                                              \
293     switch (ICV.InitKind) {                                                    \
294     case ICV_IMPLEMENTATION_DEFINED:                                           \
295       ICV.InitValue = nullptr;                                                 \
296       break;                                                                   \
297     case ICV_ZERO:                                                             \
298       ICV.InitValue = ConstantInt::get(                                        \
299           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
300       break;                                                                   \
301     case ICV_FALSE:                                                            \
302       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
303       break;                                                                   \
304     case ICV_LAST:                                                             \
305       break;                                                                   \
306     }                                                                          \
307   }
308 #include "llvm/Frontend/OpenMP/OMPKinds.def"
309   }
310 
311   /// Returns true if the function declaration \p F matches the runtime
312   /// function types, that is, return type \p RTFRetType, and argument types
313   /// \p RTFArgTypes.
314   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
315                                   SmallVector<Type *, 8> &RTFArgTypes) {
316     // TODO: We should output information to the user (under debug output
317     //       and via remarks).
318 
319     if (!F)
320       return false;
321     if (F->getReturnType() != RTFRetType)
322       return false;
323     if (F->arg_size() != RTFArgTypes.size())
324       return false;
325 
326     auto RTFTyIt = RTFArgTypes.begin();
327     for (Argument &Arg : F->args()) {
328       if (Arg.getType() != *RTFTyIt)
329         return false;
330 
331       ++RTFTyIt;
332     }
333 
334     return true;
335   }
336 
337   // Helper to collect all uses of the declaration in the UsesMap.
338   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
339     unsigned NumUses = 0;
340     if (!RFI.Declaration)
341       return NumUses;
342     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
343 
344     if (CollectStats) {
345       NumOpenMPRuntimeFunctionsIdentified += 1;
346       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
347     }
348 
349     // TODO: We directly convert uses into proper calls and unknown uses.
350     for (Use &U : RFI.Declaration->uses()) {
351       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
352         if (ModuleSlice.count(UserI->getFunction())) {
353           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
354           ++NumUses;
355         }
356       } else {
357         RFI.getOrCreateUseVector(nullptr).push_back(&U);
358         ++NumUses;
359       }
360     }
361     return NumUses;
362   }
363 
364   // Helper function to recollect uses of a runtime function.
365   void recollectUsesForFunction(RuntimeFunction RTF) {
366     auto &RFI = RFIs[RTF];
367     RFI.clearUsesMap();
368     collectUses(RFI, /*CollectStats*/ false);
369   }
370 
371   // Helper function to recollect uses of all runtime functions.
372   void recollectUses() {
373     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
374       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
375   }
376 
377   /// Helper to initialize all runtime function information for those defined
378   /// in OMPKinds.def.
379   void initializeRuntimeFunctions() {
380     Module &M = *((*ModuleSlice.begin())->getParent());
381 
382     // Helper macros for handling __VA_ARGS__ in OMP_RTL
383 #define OMP_TYPE(VarName, ...)                                                 \
384   Type *VarName = OMPBuilder.VarName;                                          \
385   (void)VarName;
386 
387 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
388   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
389   (void)VarName##Ty;                                                           \
390   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
391   (void)VarName##PtrTy;
392 
393 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
394   FunctionType *VarName = OMPBuilder.VarName;                                  \
395   (void)VarName;                                                               \
396   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
397   (void)VarName##Ptr;
398 
399 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
400   StructType *VarName = OMPBuilder.VarName;                                    \
401   (void)VarName;                                                               \
402   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
403   (void)VarName##Ptr;
404 
405 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
406   {                                                                            \
407     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
408     Function *F = M.getFunction(_Name);                                        \
409     RTLFunctions.insert(F);                                                    \
410     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
411       RuntimeFunctionIDMap[F] = _Enum;                                         \
412       auto &RFI = RFIs[_Enum];                                                 \
413       RFI.Kind = _Enum;                                                        \
414       RFI.Name = _Name;                                                        \
415       RFI.IsVarArg = _IsVarArg;                                                \
416       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
417       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
418       RFI.Declaration = F;                                                     \
419       unsigned NumUses = collectUses(RFI);                                     \
420       (void)NumUses;                                                           \
421       LLVM_DEBUG({                                                             \
422         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
423                << " found\n";                                                  \
424         if (RFI.Declaration)                                                   \
425           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
426                  << RFI.getNumFunctionsWithUses()                              \
427                  << " different functions.\n";                                 \
428       });                                                                      \
429     }                                                                          \
430   }
431 #include "llvm/Frontend/OpenMP/OMPKinds.def"
432 
433     // TODO: We should attach the attributes defined in OMPKinds.def.
434   }
435 
436   /// Collection of known kernels (\see Kernel) in the module.
437   SmallPtrSetImpl<Kernel> &Kernels;
438 
439   /// Collection of known OpenMP runtime functions.
440   DenseSet<const Function *> RTLFunctions;
441 };
442 
443 template <typename Ty, bool InsertInvalidates = true>
444 struct BooleanStateWithPtrSetVector : public BooleanState {
445 
446   bool contains(Ty *Elem) const { return Set.contains(Elem); }
447   bool insert(Ty *Elem) {
448     if (InsertInvalidates)
449       BooleanState::indicatePessimisticFixpoint();
450     return Set.insert(Elem);
451   }
452 
453   Ty *operator[](int Idx) const { return Set[Idx]; }
454   bool operator==(const BooleanStateWithPtrSetVector &RHS) const {
455     return BooleanState::operator==(RHS) && Set == RHS.Set;
456   }
457   bool operator!=(const BooleanStateWithPtrSetVector &RHS) const {
458     return !(*this == RHS);
459   }
460 
461   bool empty() const { return Set.empty(); }
462   size_t size() const { return Set.size(); }
463 
464   /// "Clamp" this state with \p RHS.
465   BooleanStateWithPtrSetVector &
466   operator^=(const BooleanStateWithPtrSetVector &RHS) {
467     BooleanState::operator^=(RHS);
468     Set.insert(RHS.Set.begin(), RHS.Set.end());
469     return *this;
470   }
471 
472 private:
473   /// A set to keep track of elements.
474   SetVector<Ty *> Set;
475 
476 public:
477   typename decltype(Set)::iterator begin() { return Set.begin(); }
478   typename decltype(Set)::iterator end() { return Set.end(); }
479   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
480   typename decltype(Set)::const_iterator end() const { return Set.end(); }
481 };
482 
483 struct KernelInfoState : AbstractState {
484   /// Flag to track if we reached a fixpoint.
485   bool IsAtFixpoint = false;
486 
487   /// The parallel regions (identified by the outlined parallel functions) that
488   /// can be reached from the associated function.
489   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
490       ReachedKnownParallelRegions;
491 
492   /// State to track what parallel region we might reach.
493   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
494 
495   /// State to track if we are in SPMD-mode, assumed or known, and why we
496   /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
497   /// also be false.
498   BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;
499 
500   /// The __kmpc_target_init call in this kernel, if any. If we find more than
501   /// one we abort as the kernel is malformed.
502   CallBase *KernelInitCB = nullptr;
503 
504   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
505   /// one we abort as the kernel is malformed.
506   CallBase *KernelDeinitCB = nullptr;
507 
508   /// Flag to indicate if the associated function is a kernel entry.
509   bool IsKernelEntry = false;
510 
511   /// State to track what kernel entries can reach the associated function.
512   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
513 
514   /// Abstract State interface
515   ///{
516 
517   KernelInfoState() {}
518   KernelInfoState(bool BestState) {
519     if (!BestState)
520       indicatePessimisticFixpoint();
521   }
522 
523   /// See AbstractState::isValidState(...)
524   bool isValidState() const override { return true; }
525 
526   /// See AbstractState::isAtFixpoint(...)
527   bool isAtFixpoint() const override { return IsAtFixpoint; }
528 
529   /// See AbstractState::indicatePessimisticFixpoint(...)
530   ChangeStatus indicatePessimisticFixpoint() override {
531     IsAtFixpoint = true;
532     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
533     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
534     return ChangeStatus::CHANGED;
535   }
536 
537   /// See AbstractState::indicateOptimisticFixpoint(...)
538   ChangeStatus indicateOptimisticFixpoint() override {
539     IsAtFixpoint = true;
540     return ChangeStatus::UNCHANGED;
541   }
542 
543   /// Return the assumed state
544   KernelInfoState &getAssumed() { return *this; }
545   const KernelInfoState &getAssumed() const { return *this; }
546 
547   bool operator==(const KernelInfoState &RHS) const {
548     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
549       return false;
550     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
551       return false;
552     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
553       return false;
554     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
555       return false;
556     return true;
557   }
558 
559   /// Return empty set as the best state of potential values.
560   static KernelInfoState getBestState() { return KernelInfoState(true); }
561 
562   static KernelInfoState getBestState(KernelInfoState &KIS) {
563     return getBestState();
564   }
565 
566   /// Return full set as the worst state of potential values.
567   static KernelInfoState getWorstState() { return KernelInfoState(false); }
568 
569   /// "Clamp" this state with \p KIS.
570   KernelInfoState operator^=(const KernelInfoState &KIS) {
571     // Do not merge two different _init and _deinit call sites.
572     if (KIS.KernelInitCB) {
573       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
574         indicatePessimisticFixpoint();
575       KernelInitCB = KIS.KernelInitCB;
576     }
577     if (KIS.KernelDeinitCB) {
578       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
579         indicatePessimisticFixpoint();
580       KernelDeinitCB = KIS.KernelDeinitCB;
581     }
582     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
583     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
584     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
585     return *this;
586   }
587 
588   KernelInfoState operator&=(const KernelInfoState &KIS) {
589     return (*this ^= KIS);
590   }
591 
592   ///}
593 };
594 
595 /// Used to map the values physically stored (in the IR) in an offload
596 /// array to a vector in memory.
597 struct OffloadArray {
598   /// Physical array (in the IR).
599   AllocaInst *Array = nullptr;
600   /// Mapped values.
601   SmallVector<Value *, 8> StoredValues;
602   /// Last stores made in the offload array.
603   SmallVector<StoreInst *, 8> LastAccesses;
604 
605   OffloadArray() = default;
606 
607   /// Initializes the OffloadArray with the values stored in \p Array before
608   /// instruction \p Before is reached. Returns false if the initialization
609   /// fails.
610   /// This MUST be used immediately after the construction of the object.
611   bool initialize(AllocaInst &Array, Instruction &Before) {
612     if (!Array.getAllocatedType()->isArrayTy())
613       return false;
614 
615     if (!getValues(Array, Before))
616       return false;
617 
618     this->Array = &Array;
619     return true;
620   }
621 
622   static const unsigned DeviceIDArgNum = 1;
623   static const unsigned BasePtrsArgNum = 3;
624   static const unsigned PtrsArgNum = 4;
625   static const unsigned SizesArgNum = 5;
626 
627 private:
628   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
629   /// \p Array, leaving StoredValues with the values stored before the
630   /// instruction \p Before is reached.
631   bool getValues(AllocaInst &Array, Instruction &Before) {
632     // Initialize container.
633     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
634     StoredValues.assign(NumValues, nullptr);
635     LastAccesses.assign(NumValues, nullptr);
636 
637     // TODO: This assumes the instruction \p Before is in the same
638     //  BasicBlock as Array. Make it general, for any control flow graph.
639     BasicBlock *BB = Array.getParent();
640     if (BB != Before.getParent())
641       return false;
642 
643     const DataLayout &DL = Array.getModule()->getDataLayout();
644     const unsigned int PointerSize = DL.getPointerSize();
645 
646     for (Instruction &I : *BB) {
647       if (&I == &Before)
648         break;
649 
650       if (!isa<StoreInst>(&I))
651         continue;
652 
653       auto *S = cast<StoreInst>(&I);
654       int64_t Offset = -1;
655       auto *Dst =
656           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
657       if (Dst == &Array) {
658         int64_t Idx = Offset / PointerSize;
659         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
660         LastAccesses[Idx] = S;
661       }
662     }
663 
664     return isFilled();
665   }
666 
667   /// Returns true if all values in StoredValues and
668   /// LastAccesses are not nullptrs.
669   bool isFilled() {
670     const unsigned NumValues = StoredValues.size();
671     for (unsigned I = 0; I < NumValues; ++I) {
672       if (!StoredValues[I] || !LastAccesses[I])
673         return false;
674     }
675 
676     return true;
677   }
678 };
679 
680 struct OpenMPOpt {
681 
682   using OptimizationRemarkGetter =
683       function_ref<OptimizationRemarkEmitter &(Function *)>;
684 
685   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
686             OptimizationRemarkGetter OREGetter,
687             OMPInformationCache &OMPInfoCache, Attributor &A)
688       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
689         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
690 
691   /// Check if any remarks are enabled for openmp-opt
692   bool remarksEnabled() {
693     auto &Ctx = M.getContext();
694     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
695   }
696 
697   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
698   bool run(bool IsModulePass) {
699     if (SCC.empty())
700       return false;
701 
702     bool Changed = false;
703 
704     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
705                       << " functions in a slice with "
706                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
707 
708     if (IsModulePass) {
709       Changed |= runAttributor(IsModulePass);
710 
711       // Recollect uses, in case Attributor deleted any.
712       OMPInfoCache.recollectUses();
713 
714       if (remarksEnabled())
715         analysisGlobalization();
716     } else {
717       if (PrintICVValues)
718         printICVs();
719       if (PrintOpenMPKernels)
720         printKernels();
721 
722       Changed |= runAttributor(IsModulePass);
723 
724       // Recollect uses, in case Attributor deleted any.
725       OMPInfoCache.recollectUses();
726 
727       Changed |= deleteParallelRegions();
728       Changed |= rewriteDeviceCodeStateMachine();
729 
730       if (HideMemoryTransferLatency)
731         Changed |= hideMemTransfersLatency();
732       Changed |= deduplicateRuntimeCalls();
733       if (EnableParallelRegionMerging) {
734         if (mergeParallelRegions()) {
735           deduplicateRuntimeCalls();
736           Changed = true;
737         }
738       }
739     }
740 
741     return Changed;
742   }
743 
744   /// Print initial ICV values for testing.
745   /// FIXME: This should be done from the Attributor once it is added.
746   void printICVs() const {
747     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
748                                  ICV_proc_bind};
749 
750     for (Function *F : OMPInfoCache.ModuleSlice) {
751       for (auto ICV : ICVs) {
752         auto ICVInfo = OMPInfoCache.ICVs[ICV];
753         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
754           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
755                      << " Value: "
756                      << (ICVInfo.InitValue
757                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
758                              : "IMPLEMENTATION_DEFINED");
759         };
760 
761         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
762       }
763     }
764   }
765 
766   /// Print OpenMP GPU kernels for testing.
767   void printKernels() const {
768     for (Function *F : SCC) {
769       if (!OMPInfoCache.Kernels.count(F))
770         continue;
771 
772       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
773         return ORA << "OpenMP GPU kernel "
774                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
775       };
776 
777       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
778     }
779   }
780 
781   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
782   /// given, the callee must be \p RFI's declaration or nullptr is returned.
783   static CallInst *getCallIfRegularCall(
784       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
785     CallInst *CI = dyn_cast<CallInst>(U.getUser());
786     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
787         (!RFI || CI->getCalledFunction() == RFI->Declaration))
788       return CI;
789     return nullptr;
790   }
791 
792   /// Return the call if \p V is a regular call. If \p RFI is given, the callee
793   /// must be \p RFI's declaration or nullptr is returned.
794   static CallInst *getCallIfRegularCall(
795       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
796     CallInst *CI = dyn_cast<CallInst>(&V);
797     if (CI && !CI->hasOperandBundles() &&
798         (!RFI || CI->getCalledFunction() == RFI->Declaration))
799       return CI;
800     return nullptr;
801   }
802 
803 private:
804   /// Merge parallel regions when it is safe.
805   bool mergeParallelRegions() {
806     const unsigned CallbackCalleeOperand = 2;
807     const unsigned CallbackFirstArgOperand = 3;
808     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
809 
810     // Check if there are any __kmpc_fork_call calls to merge.
811     OMPInformationCache::RuntimeFunctionInfo &RFI =
812         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
813 
814     if (!RFI.Declaration)
815       return false;
816 
817     // Unmergable calls that prevent merging a parallel region.
818     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
819         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
820         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
821     };
822 
823     bool Changed = false;
824     LoopInfo *LI = nullptr;
825     DominatorTree *DT = nullptr;
826 
827     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
828 
829     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
830     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
831                          BasicBlock &ContinuationIP) {
832       BasicBlock *CGStartBB = CodeGenIP.getBlock();
833       BasicBlock *CGEndBB =
834           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
835       assert(StartBB != nullptr && "StartBB should not be null");
836       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
837       assert(EndBB != nullptr && "EndBB should not be null");
838       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
839     };
840 
841     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
842                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
843       ReplacementValue = &Inner;
844       return CodeGenIP;
845     };
846 
847     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
848 
849     /// Create a sequential execution region within a merged parallel region,
850     /// encapsulated in a master construct with a barrier for synchronization.
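    /// Illustrative sketch of the intended result (simplified; exact runtime
    /// arguments are omitted):
    ///   if (__kmpc_master(...)) {
    ///     ... sequential instructions ...
    ///     __kmpc_end_master(...);
    ///   }
    ///   __kmpc_barrier(...); // All threads synchronize after the region.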
851     auto CreateSequentialRegion = [&](Function *OuterFn,
852                                       BasicBlock *OuterPredBB,
853                                       Instruction *SeqStartI,
854                                       Instruction *SeqEndI) {
855       // Isolate the instructions of the sequential region to a separate
856       // block.
857       BasicBlock *ParentBB = SeqStartI->getParent();
858       BasicBlock *SeqEndBB =
859           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
860       BasicBlock *SeqAfterBB =
861           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
862       BasicBlock *SeqStartBB =
863           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
864 
865       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
866              "Expected a different CFG");
867       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
868       ParentBB->getTerminator()->eraseFromParent();
869 
870       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
871                            BasicBlock &ContinuationIP) {
872         BasicBlock *CGStartBB = CodeGenIP.getBlock();
873         BasicBlock *CGEndBB =
874             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
875         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
876         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
877         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
878         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
879       };
880       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
881 
882       // Find outputs from the sequential region to outside users and
883       // broadcast their values to them.
884       for (Instruction &I : *SeqStartBB) {
885         SmallPtrSet<Instruction *, 4> OutsideUsers;
886         for (User *Usr : I.users()) {
887           Instruction &UsrI = *cast<Instruction>(Usr);
888           // Ignore outputs to lifetime intrinsics; code extraction for the
889           // merged parallel region will fix them.
890           if (UsrI.isLifetimeStartOrEnd())
891             continue;
892 
893           if (UsrI.getParent() != SeqStartBB)
894             OutsideUsers.insert(&UsrI);
895         }
896 
897         if (OutsideUsers.empty())
898           continue;
899 
900         // Emit an alloca in the outer region to store the broadcasted
901         // value.
902         const DataLayout &DL = M.getDataLayout();
903         AllocaInst *AllocaI = new AllocaInst(
904             I.getType(), DL.getAllocaAddrSpace(), nullptr,
905             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
906 
907         // Emit a store instruction in the sequential BB to update the
908         // value.
909         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
910 
911         // Emit a load instruction and replace the use of the output value
912         // with it.
913         for (Instruction *UsrI : OutsideUsers) {
914           LoadInst *LoadI = new LoadInst(
915               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
916           UsrI->replaceUsesOfWith(&I, LoadI);
917         }
918       }
919 
920       OpenMPIRBuilder::LocationDescription Loc(
921           InsertPointTy(ParentBB, ParentBB->end()), DL);
922       InsertPointTy SeqAfterIP =
923           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
924 
925       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
926 
927       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
928 
929       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
930                         << "\n");
931     };
932 
933     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
934     // contained in BB and only separated by instructions that can be
935     // redundantly executed in parallel. The block BB is split before the first
936     // call (in MergableCIs) and after the last so the entire region we merge
937     // into a single parallel region is contained in a single basic block
938     // without any other instructions. We use the OpenMPIRBuilder to outline
939     // that block and call the resulting function via __kmpc_fork_call.
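    // Illustrative sketch (simplified; @cb1, @cb2, and @merged are placeholder
    // names): two adjacent regions
    //   __kmpc_fork_call(..., @cb1, ...)
    //   <redundantly executable instructions>
    //   __kmpc_fork_call(..., @cb2, ...)
    // become a single
    //   __kmpc_fork_call(..., @merged, ...)
    // where @merged calls @cb1 and @cb2 directly, separated by an explicit
    // __kmpc_barrier that replaces the implicit fork-join barrier.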
940     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
941       // TODO: Change the interface to allow single CIs expanded, e.g., to
942       // include an outer loop.
943       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
944 
945       auto Remark = [&](OptimizationRemark OR) {
946         OR << "Parallel region merged with parallel region"
947            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
948         for (auto *CI : llvm::drop_begin(MergableCIs)) {
949           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
950           if (CI != MergableCIs.back())
951             OR << ", ";
952         }
953         return OR << ".";
954       };
955 
956       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
957 
958       Function *OriginalFn = BB->getParent();
959       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
960                         << " parallel regions in " << OriginalFn->getName()
961                         << "\n");
962 
963       // Isolate the calls to merge in a separate block.
964       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
965       BasicBlock *AfterBB =
966           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
967       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
968                            "omp.par.merged");
969 
970       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
971       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
972       BB->getTerminator()->eraseFromParent();
973 
974       // Create sequential regions for sequential instructions that are
975       // in-between mergable parallel regions.
976       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
977            It != End; ++It) {
978         Instruction *ForkCI = *It;
979         Instruction *NextForkCI = *(It + 1);
980 
981         // Continue if there are no in-between instructions.
982         if (ForkCI->getNextNode() == NextForkCI)
983           continue;
984 
985         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
986                                NextForkCI->getPrevNode());
987       }
988 
989       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
990                                                DL);
991       IRBuilder<>::InsertPoint AllocaIP(
992           &OriginalFn->getEntryBlock(),
993           OriginalFn->getEntryBlock().getFirstInsertionPt());
994       // Create the merged parallel region with default proc binding, to
995       // avoid overriding binding settings, and without explicit cancellation.
996       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
997           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
998           OMP_PROC_BIND_default, /* IsCancellable */ false);
999       BranchInst::Create(AfterBB, AfterIP.getBlock());
1000 
1001       // Perform the actual outlining.
1002       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1003                                        /* AllowExtractorSinking */ true);
1004 
1005       Function *OutlinedFn = MergableCIs.front()->getCaller();
1006 
1007       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1008       // callbacks.
1009       SmallVector<Value *, 8> Args;
1010       for (auto *CI : MergableCIs) {
1011         Value *Callee =
1012             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1013         FunctionType *FT =
1014             cast<FunctionType>(Callee->getType()->getPointerElementType());
1015         Args.clear();
1016         Args.push_back(OutlinedFn->getArg(0));
1017         Args.push_back(OutlinedFn->getArg(1));
1018         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1019              U < E; ++U)
1020           Args.push_back(CI->getArgOperand(U));
1021 
1022         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1023         if (CI->getDebugLoc())
1024           NewCI->setDebugLoc(CI->getDebugLoc());
1025 
1026         // Forward parameter attributes from the callback to the callee.
1027         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1028              U < E; ++U)
1029           for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
1030             NewCI->addParamAttr(
1031                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1032 
1033         // Emit an explicit barrier to replace the implicit fork-join barrier.
1034         if (CI != MergableCIs.back()) {
1035           // TODO: Remove barrier if the merged parallel region includes the
1036           // 'nowait' clause.
1037           OMPInfoCache.OMPBuilder.createBarrier(
1038               InsertPointTy(NewCI->getParent(),
1039                             NewCI->getNextNode()->getIterator()),
1040               OMPD_parallel);
1041         }
1042 
1043         CI->eraseFromParent();
1044       }
1045 
1046       assert(OutlinedFn != OriginalFn && "Outlining failed");
1047       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1048       CGUpdater.reanalyzeFunction(*OriginalFn);
1049 
1050       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1051 
1052       return true;
1053     };
1054 
1055     // Helper function that identifies sequences of
1056     // __kmpc_fork_call uses in a basic block.
1057     auto DetectPRsCB = [&](Use &U, Function &F) {
1058       CallInst *CI = getCallIfRegularCall(U, &RFI);
1059       BB2PRMap[CI->getParent()].insert(CI);
1060 
1061       return false;
1062     };
1063 
1064     BB2PRMap.clear();
1065     RFI.foreachUse(SCC, DetectPRsCB);
1066     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1067     // Find mergable parallel regions within a basic block that are
1068     // safe to merge, that is any in-between instructions can safely
1069     // execute in parallel after merging.
1070     // TODO: support merging across basic-blocks.
1071     for (auto &It : BB2PRMap) {
1072       auto &CIs = It.getSecond();
1073       if (CIs.size() < 2)
1074         continue;
1075 
1076       BasicBlock *BB = It.getFirst();
1077       SmallVector<CallInst *, 4> MergableCIs;
1078 
1079       /// Returns true if the instruction is mergable, false otherwise.
1080       /// A terminator instruction is unmergable by definition since merging
1081       /// works within a BB. Instructions before the mergable region are
1082       /// mergable if they are not calls to OpenMP runtime functions that may
1083       /// set different execution parameters for subsequent parallel regions.
1084       /// Instructions in-between parallel regions are mergable if they are not
1085       /// calls to any non-intrinsic function since that may call a non-mergable
1086       /// OpenMP runtime function.
1087       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1088         // We do not merge across BBs, hence return false (unmergable) if the
1089         // instruction is a terminator.
1090         if (I.isTerminator())
1091           return false;
1092 
1093         if (!isa<CallInst>(&I))
1094           return true;
1095 
1096         CallInst *CI = cast<CallInst>(&I);
1097         if (IsBeforeMergableRegion) {
1098           Function *CalledFunction = CI->getCalledFunction();
1099           if (!CalledFunction)
1100             return false;
1101           // Return false (unmergable) if the call before the parallel
1102           // region calls an explicit affinity (proc_bind) or number of
1103           // threads (num_threads) compiler-generated function. Those settings
1104           // may be incompatible with following parallel regions.
1105           // TODO: ICV tracking to detect compatibility.
1106           for (const auto &RFI : UnmergableCallsInfo) {
1107             if (CalledFunction == RFI.Declaration)
1108               return false;
1109           }
1110         } else {
1111           // Return false (unmergable) if there is a call instruction
1112           // in-between parallel regions when it is not an intrinsic. It
1113           // may call an unmergable OpenMP runtime function in its callpath.
1114           // TODO: Keep track of possible OpenMP calls in the callpath.
1115           if (!isa<IntrinsicInst>(CI))
1116             return false;
1117         }
1118 
1119         return true;
1120       };
1121       // Find maximal number of parallel region CIs that are safe to merge.
1122       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1123         Instruction &I = *It;
1124         ++It;
1125 
1126         if (CIs.count(&I)) {
1127           MergableCIs.push_back(cast<CallInst>(&I));
1128           continue;
1129         }
1130 
1131         // Continue expanding if the instruction is mergable.
1132         if (IsMergable(I, MergableCIs.empty()))
1133           continue;
1134 
1135         // Forward the instruction iterator to skip the next parallel region
1136         // since there is an unmergable instruction which can affect it.
1137         for (; It != End; ++It) {
1138           Instruction &SkipI = *It;
1139           if (CIs.count(&SkipI)) {
1140             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1141                               << " due to " << I << "\n");
1142             ++It;
1143             break;
1144           }
1145         }
1146 
1147         // Store mergable regions found.
1148         if (MergableCIs.size() > 1) {
1149           MergableCIsVector.push_back(MergableCIs);
1150           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1151                             << " parallel regions in block " << BB->getName()
1152                             << " of function " << BB->getParent()->getName()
1153                             << "\n";);
1154         }
1155 
1156         MergableCIs.clear();
1157       }
1158 
1159       if (!MergableCIsVector.empty()) {
1160         Changed = true;
1161 
1162         for (auto &MergableCIs : MergableCIsVector)
1163           Merge(MergableCIs, BB);
1164         MergableCIsVector.clear();
1165       }
1166     }
1167 
1168     if (Changed) {
1169       /// Re-collect uses for fork calls, emitted barrier calls, and
1170       /// any emitted master/end_master calls.
1171       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1172       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1173       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1174       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1175     }
1176 
1177     return Changed;
1178   }
1179 
1180   /// Try to delete parallel regions if possible.
1181   bool deleteParallelRegions() {
1182     const unsigned CallbackCalleeOperand = 2;
1183 
1184     OMPInformationCache::RuntimeFunctionInfo &RFI =
1185         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1186 
1187     if (!RFI.Declaration)
1188       return false;
1189 
1190     bool Changed = false;
1191     auto DeleteCallCB = [&](Use &U, Function &) {
1192       CallInst *CI = getCallIfRegularCall(U);
1193       if (!CI)
1194         return false;
1195       auto *Fn = dyn_cast<Function>(
1196           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1197       if (!Fn)
1198         return false;
1199       if (!Fn->onlyReadsMemory())
1200         return false;
1201       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1202         return false;
1203 
1204       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1205                         << CI->getCaller()->getName() << "\n");
1206 
1207       auto Remark = [&](OptimizationRemark OR) {
1208         return OR << "Removing parallel region with no side-effects.";
1209       };
1210       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1211 
1212       CGUpdater.removeCallSite(*CI);
1213       CI->eraseFromParent();
1214       Changed = true;
1215       ++NumOpenMPParallelRegionsDeleted;
1216       return true;
1217     };
1218 
1219     RFI.foreachUse(SCC, DeleteCallCB);
1220 
1221     return Changed;
1222   }
1223 
1224   /// Try to eliminate runtime calls by reusing existing ones.
1225   bool deduplicateRuntimeCalls() {
1226     bool Changed = false;
1227 
1228     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1229         OMPRTL_omp_get_num_threads,
1230         OMPRTL_omp_in_parallel,
1231         OMPRTL_omp_get_cancellation,
1232         OMPRTL_omp_get_thread_limit,
1233         OMPRTL_omp_get_supported_active_levels,
1234         OMPRTL_omp_get_level,
1235         OMPRTL_omp_get_ancestor_thread_num,
1236         OMPRTL_omp_get_team_size,
1237         OMPRTL_omp_get_active_level,
1238         OMPRTL_omp_in_final,
1239         OMPRTL_omp_get_proc_bind,
1240         OMPRTL_omp_get_num_places,
1241         OMPRTL_omp_get_num_procs,
1242         OMPRTL_omp_get_place_num,
1243         OMPRTL_omp_get_partition_num_places,
1244         OMPRTL_omp_get_partition_place_nums};
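    // Illustrative sketch (placeholder SSA names): two calls such as
    //   %a = call i32 @omp_get_thread_limit()
    //   ...
    //   %b = call i32 @omp_get_thread_limit()
    // within one function can be deduplicated by reusing %a in place of %b,
    // assuming nothing in between can change the queried ICV.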
1245 
1246     // Global-tid is handled separately.
1247     SmallSetVector<Value *, 16> GTIdArgs;
1248     collectGlobalThreadIdArguments(GTIdArgs);
1249     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1250                       << " global thread ID arguments\n");
1251 
1252     for (Function *F : SCC) {
1253       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1254         Changed |= deduplicateRuntimeCalls(
1255             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1256 
1257       // __kmpc_global_thread_num is special as we can replace it with an
1258       // argument in enough cases to make it worth trying.
1259       Value *GTIdArg = nullptr;
1260       for (Argument &Arg : F->args())
1261         if (GTIdArgs.count(&Arg)) {
1262           GTIdArg = &Arg;
1263           break;
1264         }
1265       Changed |= deduplicateRuntimeCalls(
1266           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1267     }
1268 
1269     return Changed;
1270   }
1271 
1272   /// Tries to hide the latency of runtime calls that involve host to
1273   /// device memory transfers by splitting them into their "issue" and "wait"
1274   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1275   /// moved downwards as much as possible. The "issue" issues the memory
1276   /// transfer asynchronously, returning a handle. The "wait" waits on the
1277   /// returned handle for the memory transfer to finish.
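  ///
  /// Conceptually (an illustrative sketch; the exact names of the issue/wait
  /// runtime entry points are assumptions here):
  ///   %handle = call ... @__tgt_target_data_begin_mapper_issue(...)  ; hoisted
  ///   ... independent work ...
  ///   call void @__tgt_target_data_begin_mapper_wait(..., %handle)   ; sunk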
1278   bool hideMemTransfersLatency() {
1279     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1280     bool Changed = false;
1281     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1282       auto *RTCall = getCallIfRegularCall(U, &RFI);
1283       if (!RTCall)
1284         return false;
1285 
1286       OffloadArray OffloadArrays[3];
1287       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1288         return false;
1289 
1290       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1291 
1292       // TODO: Check if can be moved upwards.
1293       bool WasSplit = false;
1294       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1295       if (WaitMovementPoint)
1296         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1297 
1298       Changed |= WasSplit;
1299       return WasSplit;
1300     };
1301     RFI.foreachUse(SCC, SplitMemTransfers);
1302 
1303     return Changed;
1304   }
1305 
1306   void analysisGlobalization() {
1307     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1308 
1309     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1310       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1311         auto Remark = [&](OptimizationRemarkMissed ORM) {
1312           return ORM
1313                  << "Found thread data sharing on the GPU. "
1314                  << "Expect degraded performance due to data globalization.";
1315         };
1316         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1317       }
1318 
1319       return false;
1320     };
1321 
1322     RFI.foreachUse(SCC, CheckGlobalization);
1323   }
1324 
1325   /// Maps the values stored in the offload arrays passed as arguments to
1326   /// \p RuntimeCall into the offload arrays in \p OAs.
1327   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1328                                 MutableArrayRef<OffloadArray> OAs) {
1329     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1330 
1331     // A runtime call that involves memory offloading looks something like:
1332     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1333     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1334     // ...)
1335     // So, the idea is to access the allocas that allocate space for these
1336     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1337     // Therefore:
1338     // i8** %offload_baseptrs.
1339     Value *BasePtrsArg =
1340         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1341     // i8** %offload_ptrs.
1342     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i64* %offload_sizes.
1344     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1345 
1346     // Get values stored in **offload_baseptrs.
1347     auto *V = getUnderlyingObject(BasePtrsArg);
1348     if (!isa<AllocaInst>(V))
1349       return false;
1350     auto *BasePtrsArray = cast<AllocaInst>(V);
1351     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1352       return false;
1353 
    // Get values stored in **offload_ptrs.
1355     V = getUnderlyingObject(PtrsArg);
1356     if (!isa<AllocaInst>(V))
1357       return false;
1358     auto *PtrsArray = cast<AllocaInst>(V);
1359     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1360       return false;
1361 
1362     // Get values stored in **offload_sizes.
1363     V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array, don't analyze it.
1365     if (isa<GlobalValue>(V))
1366       return isa<Constant>(V);
1367     if (!isa<AllocaInst>(V))
1368       return false;
1369 
1370     auto *SizesArray = cast<AllocaInst>(V);
1371     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1372       return false;
1373 
1374     return true;
1375   }
1376 
1377   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1378   /// For now this is a way to test that the function getValuesInOffloadArrays
1379   /// is working properly.
1380   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1381   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1382     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1383 
1384     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1385     std::string ValuesStr;
1386     raw_string_ostream Printer(ValuesStr);
1387     std::string Separator = " --- ";
1388 
1389     for (auto *BP : OAs[0].StoredValues) {
1390       BP->print(Printer);
1391       Printer << Separator;
1392     }
1393     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1394     ValuesStr.clear();
1395 
1396     for (auto *P : OAs[1].StoredValues) {
1397       P->print(Printer);
1398       Printer << Separator;
1399     }
1400     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1401     ValuesStr.clear();
1402 
1403     for (auto *S : OAs[2].StoredValues) {
1404       S->print(Printer);
1405       Printer << Separator;
1406     }
1407     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1408   }
1409 
  /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
1412   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1413     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1414     //  Make it traverse the CFG.
1415 
1416     Instruction *CurrentI = &RuntimeCall;
1417     bool IsWorthIt = false;
1418     while ((CurrentI = CurrentI->getNextNode())) {
1419 
1420       // TODO: Once we detect the regions to be offloaded we should use the
1421       //  alias analysis manager to check if CurrentI may modify one of
1422       //  the offloaded regions.
1423       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1424         if (IsWorthIt)
1425           return CurrentI;
1426 
1427         return nullptr;
1428       }
1429 
      // FIXME: For now, moving it over anything without side effects is
      //  considered worth it.
1432       IsWorthIt = true;
1433     }
1434 
1435     // Return end of BasicBlock.
1436     return RuntimeCall.getParent()->getTerminator();
1437   }
1438 
1439   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1440   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1441                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It is used to store information about the async transfer,
    // allowing us to wait on it later.
1445     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1446     auto *F = RuntimeCall.getCaller();
1447     Instruction *FirstInst = &(F->getEntryBlock().front());
1448     AllocaInst *Handle = new AllocaInst(
1449         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1450 
    // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_mapper_issue(i64, i32,
    //   i8**, i8**, i64*, i64*)
1454     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1455         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1456 
    // Change the RuntimeCall call site to its asynchronous version.
1458     SmallVector<Value *, 16> Args;
1459     for (auto &Arg : RuntimeCall.args())
1460       Args.push_back(Arg.get());
1461     Args.push_back(Handle);
1462 
1463     CallInst *IssueCallsite =
1464         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1465     RuntimeCall.eraseFromParent();
1466 
    // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_mapper_wait(i64, %struct.__tgt_async_info*)
1469     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1470         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1471 
1472     Value *WaitParams[2] = {
1473         IssueCallsite->getArgOperand(
1474             OffloadArray::DeviceIDArgNum), // device_id.
1475         Handle                             // handle to wait on.
1476     };
1477     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1478 
1479     return true;
1480   }
1481 
1482   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1483                                     bool GlobalOnly, bool &SingleChoice) {
1484     if (CurrentIdent == NextIdent)
1485       return CurrentIdent;
1486 
1487     // TODO: Figure out how to actually combine multiple debug locations. For
1488     //       now we just keep an existing one if there is a single choice.
1489     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1490       SingleChoice = !CurrentIdent;
1491       return NextIdent;
1492     }
1493     return nullptr;
1494   }
1495 
  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
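  ///
  /// For example (informally), if every call of \p RFI in \p F passes the same
  /// global `@ident` as its first argument, that global is returned. If the
  /// calls disagree, or if \p GlobalOnly is true and only local idents are
  /// found, a default ident is created via the OpenMPIRBuilder instead.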
1501   Value *
1502   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1503                                  Function &F, bool GlobalOnly) {
1504     bool SingleChoice = true;
1505     Value *Ident = nullptr;
1506     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1507       CallInst *CI = getCallIfRegularCall(U, &RFI);
1508       if (!CI || &F != &Caller)
1509         return false;
1510       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1511                                   /* GlobalOnly */ true, SingleChoice);
1512       return false;
1513     };
1514     RFI.foreachUse(SCC, CombineIdentStruct);
1515 
1516     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate, but we work around it for now.
1519       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1520         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1521             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1523       // TODO: Use the debug locations of the calls instead.
1524       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1525       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1526     }
1527     return Ident;
1528   }
1529 
1530   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1531   /// \p ReplVal if given.
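  ///
  /// An illustrative sketch (value names are made up):
  ///
  ///   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @ident)
  ///   ...
  ///   %tid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @ident)
  ///   use(%tid1)
  ///
  /// becomes
  ///
  ///   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @ident)
  ///   ...
  ///   use(%tid0)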
1532   bool deduplicateRuntimeCalls(Function &F,
1533                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1534                                Value *ReplVal = nullptr) {
1535     auto *UV = RFI.getUseVector(F);
1536     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1537       return false;
1538 
    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value" : "") << "\n");
1542 
1543     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1544                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1545            "Unexpected replacement value!");
1546 
1547     // TODO: Use dominance to find a good position instead.
1548     auto CanBeMoved = [this](CallBase &CB) {
1549       unsigned NumArgs = CB.getNumArgOperands();
1550       if (NumArgs == 0)
1551         return true;
1552       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1553         return false;
1554       for (unsigned u = 1; u < NumArgs; ++u)
1555         if (isa<Instruction>(CB.getArgOperand(u)))
1556           return false;
1557       return true;
1558     };
1559 
1560     if (!ReplVal) {
1561       for (Use *U : *UV)
1562         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1563           if (!CanBeMoved(*CI))
1564             continue;
1565 
1566           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1567           ReplVal = CI;
1568           break;
1569         }
1570       if (!ReplVal)
1571         return false;
1572     }
1573 
1574     // If we use a call as a replacement value we need to make sure the ident is
1575     // valid at the new location. For now we just pick a global one, either
1576     // existing and used by one of the calls, or created from scratch.
1577     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1578       if (CI->getNumArgOperands() > 0 &&
1579           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1580         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1581                                                       /* GlobalOnly */ true);
1582         CI->setArgOperand(0, Ident);
1583       }
1584     }
1585 
1586     bool Changed = false;
1587     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1588       CallInst *CI = getCallIfRegularCall(U, &RFI);
1589       if (!CI || CI == ReplVal || &F != &Caller)
1590         return false;
1591       assert(CI->getCaller() == &F && "Unexpected call!");
1592 
1593       auto Remark = [&](OptimizationRemark OR) {
1594         return OR << "OpenMP runtime call "
1595                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1596       };
1597       if (CI->getDebugLoc())
1598         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1599       else
1600         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1601 
1602       CGUpdater.removeCallSite(*CI);
1603       CI->replaceAllUsesWith(ReplVal);
1604       CI->eraseFromParent();
1605       ++NumOpenMPRuntimeCallsDeduplicated;
1606       Changed = true;
1607       return true;
1608     };
1609     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1610 
1611     return Changed;
1612   }
1613 
1614   /// Collect arguments that represent the global thread id in \p GTIdArgs.
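  ///
  /// For instance (a hedged sketch), given
  ///
  ///   %tid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @ident)
  ///   call void @helper(i32 %tid)   ; @helper has local linkage
  ///
  /// the first argument of @helper is added to \p GTIdArgs if all other call
  /// sites of @helper also pass a known GTId in that position.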
1615   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1616     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1617     //       initialization. We could define an AbstractAttribute instead and
1618     //       run the Attributor here once it can be run as an SCC pass.
1619 
1620     // Helper to check the argument \p ArgNo at all call sites of \p F for
1621     // a GTId.
1622     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1623       if (!F.hasLocalLinkage())
1624         return false;
1625       for (Use &U : F.uses()) {
1626         if (CallInst *CI = getCallIfRegularCall(U)) {
1627           Value *ArgOp = CI->getArgOperand(ArgNo);
1628           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1629               getCallIfRegularCall(
1630                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1631             continue;
1632         }
1633         return false;
1634       }
1635       return true;
1636     };
1637 
1638     // Helper to identify uses of a GTId as GTId arguments.
1639     auto AddUserArgs = [&](Value &GTId) {
1640       for (Use &U : GTId.uses())
1641         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1642           if (CI->isArgOperand(&U))
1643             if (Function *Callee = CI->getCalledFunction())
1644               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1645                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1646     };
1647 
1648     // The argument users of __kmpc_global_thread_num calls are GTIds.
1649     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1650         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1651 
1652     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1653       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1654         AddUserArgs(*CI);
1655       return false;
1656     });
1657 
    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended,
    // so we cannot cache the size nor use a range-based for loop.
1661     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1662       AddUserArgs(*GTIdArgs[u]);
1663   }
1664 
1665   /// Kernel (=GPU) optimizations and utility functions
1666   ///
1667   ///{{
1668 
1669   /// Check if \p F is a kernel, hence entry point for target offloading.
1670   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1671 
1672   /// Cache to remember the unique kernel for a function.
1673   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1674 
1675   /// Find the unique kernel that will execute \p F, if any.
1676   Kernel getUniqueKernelFor(Function &F);
1677 
1678   /// Find the unique kernel that will execute \p I, if any.
1679   Kernel getUniqueKernelFor(Instruction &I) {
1680     return getUniqueKernelFor(*I.getFunction());
1681   }
1682 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
1685   bool rewriteDeviceCodeStateMachine();
1686 
1687   ///
1688   ///}}
1689 
1690   /// Emit a remark generically
1691   ///
1692   /// This template function can be used to generically emit a remark. The
1693   /// RemarkKind should be one of the following:
1694   ///   - OptimizationRemark to indicate a successful optimization attempt
1695   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1696   ///   - OptimizationRemarkAnalysis to provide additional information about an
1697   ///     optimization attempt
1698   ///
1699   /// The remark is built using a callback function provided by the caller that
1700   /// takes a RemarkKind as input and returns a RemarkKind.
1701   template <typename RemarkKind, typename RemarkCallBack>
1702   void emitRemark(Instruction *I, StringRef RemarkName,
1703                   RemarkCallBack &&RemarkCB) const {
1704     Function *F = I->getParent()->getParent();
1705     auto &ORE = OREGetter(F);
1706 
1707     if (RemarkName.startswith("OMP"))
1708       ORE.emit([&]() {
1709         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1710                << " [" << RemarkName << "]";
1711       });
1712     else
1713       ORE.emit(
1714           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1715   }
1716 
1717   /// Emit a remark on a function.
1718   template <typename RemarkKind, typename RemarkCallBack>
1719   void emitRemark(Function *F, StringRef RemarkName,
1720                   RemarkCallBack &&RemarkCB) const {
1721     auto &ORE = OREGetter(F);
1722 
1723     if (RemarkName.startswith("OMP"))
1724       ORE.emit([&]() {
1725         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1726                << " [" << RemarkName << "]";
1727       });
1728     else
1729       ORE.emit(
1730           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1731   }
1732 
1733   /// The underlying module.
1734   Module &M;
1735 
1736   /// The SCC we are operating on.
1737   SmallVectorImpl<Function *> &SCC;
1738 
1739   /// Callback to update the call graph, the first argument is a removed call,
1740   /// the second an optional replacement call.
1741   CallGraphUpdater &CGUpdater;
1742 
1743   /// Callback to get an OptimizationRemarkEmitter from a Function *
1744   OptimizationRemarkGetter OREGetter;
1745 
  /// OpenMP-specific information cache. Also used for Attributor runs.
1747   OMPInformationCache &OMPInfoCache;
1748 
1749   /// Attributor instance.
1750   Attributor &A;
1751 
1752   /// Helper function to run Attributor on SCC.
1753   bool runAttributor(bool IsModulePass) {
1754     if (SCC.empty())
1755       return false;
1756 
1757     registerAAs(IsModulePass);
1758 
1759     ChangeStatus Changed = A.run();
1760 
1761     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1762                       << " functions, result: " << Changed << ".\n");
1763 
1764     return Changed == ChangeStatus::CHANGED;
1765   }
1766 
1767   /// Populate the Attributor with abstract attribute opportunities in the
1768   /// function.
1769   void registerAAs(bool IsModulePass);
1770 };
1771 
1772 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1773   if (!OMPInfoCache.ModuleSlice.count(&F))
1774     return nullptr;
1775 
1776   // Use a scope to keep the lifetime of the CachedKernel short.
1777   {
1778     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1779     if (CachedKernel)
1780       return *CachedKernel;
1781 
1782     // TODO: We should use an AA to create an (optimistic and callback
1783     //       call-aware) call graph. For now we stick to simple patterns that
1784     //       are less powerful, basically the worst fixpoint.
1785     if (isKernel(F)) {
1786       CachedKernel = Kernel(&F);
1787       return *CachedKernel;
1788     }
1789 
1790     CachedKernel = nullptr;
1791     if (!F.hasLocalLinkage()) {
1792 
1793       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1794       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1795         return ORA << "Potentially unknown OpenMP target region caller.";
1796       };
1797       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1798 
1799       return nullptr;
1800     }
1801   }
1802 
1803   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1804     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1805       // Allow use in equality comparisons.
1806       if (Cmp->isEquality())
1807         return getUniqueKernelFor(*Cmp);
1808       return nullptr;
1809     }
1810     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1811       // Allow direct calls.
1812       if (CB->isCallee(&U))
1813         return getUniqueKernelFor(*CB);
1814 
1815       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1816           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1817       // Allow the use in __kmpc_parallel_51 calls.
1818       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1819         return getUniqueKernelFor(*CB);
1820       return nullptr;
1821     }
1822     // Disallow every other use.
1823     return nullptr;
1824   };
1825 
1826   // TODO: In the future we want to track more than just a unique kernel.
1827   SmallPtrSet<Kernel, 2> PotentialKernels;
1828   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1829     PotentialKernels.insert(GetUniqueKernelForUse(U));
1830   });
1831 
1832   Kernel K = nullptr;
1833   if (PotentialKernels.size() == 1)
1834     K = *PotentialKernels.begin();
1835 
1836   // Cache the result.
1837   UniqueKernelMap[&F] = K;
1838 
1839   return K;
1840 }
1841 
1842 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1843   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1844       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1845 
1846   bool Changed = false;
1847   if (!KernelParallelRFI)
1848     return Changed;
1849 
1850   for (Function *F : SCC) {
1851 
1852     // Check if the function is a use in a __kmpc_parallel_51 call at
1853     // all.
1854     bool UnknownUse = false;
1855     bool KernelParallelUse = false;
1856     unsigned NumDirectCalls = 0;
1857 
1858     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1859     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1860       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1861         if (CB->isCallee(&U)) {
1862           ++NumDirectCalls;
1863           return;
1864         }
1865 
1866       if (isa<ICmpInst>(U.getUser())) {
1867         ToBeReplacedStateMachineUses.push_back(&U);
1868         return;
1869       }
1870 
1871       // Find wrapper functions that represent parallel kernels.
1872       CallInst *CI =
1873           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1874       const unsigned int WrapperFunctionArgNo = 6;
1875       if (!KernelParallelUse && CI &&
1876           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1877         KernelParallelUse = true;
1878         ToBeReplacedStateMachineUses.push_back(&U);
1879         return;
1880       }
1881       UnknownUse = true;
1882     });
1883 
1884     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1885     // use.
1886     if (!KernelParallelUse)
1887       continue;
1888 
1889     // If this ever hits, we should investigate.
1890     // TODO: Checking the number of uses is not a necessary restriction and
1891     // should be lifted.
1892     if (UnknownUse || NumDirectCalls != 1 ||
1893         ToBeReplacedStateMachineUses.size() > 2) {
1894       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1895         return ORA << "Parallel region is used in "
1896                    << (UnknownUse ? "unknown" : "unexpected")
1897                    << " ways. Will not attempt to rewrite the state machine.";
1898       };
1899       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
1900       continue;
1901     }
1902 
1903     // Even if we have __kmpc_parallel_51 calls, we (for now) give
1904     // up if the function is not called from a unique kernel.
1905     Kernel K = getUniqueKernelFor(*F);
1906     if (!K) {
1907       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1908         return ORA << "Parallel region is not called from a unique kernel. "
1909                       "Will not attempt to rewrite the state machine.";
1910       };
1911       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
1912       continue;
1913     }
1914 
    // We now know F is a parallel body function called only from the kernel K.
    // We also identified the state machine uses in which we will replace the
    // function pointer by a new global symbol used purely for identification.
    // This ensures only direct calls to the function are left.
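    // Conceptually (illustrative IR, the names are made up), a use like
    //
    //   call void @__kmpc_parallel_51(..., i8* bitcast (void ()* @F to i8*), ...)
    //
    // becomes
    //
    //   @F.ID = private constant i8 undef
    //   call void @__kmpc_parallel_51(..., i8* @F.ID, ...)
    //
    // and equality comparisons in the state machine now compare against @F.ID,
    // so the only remaining uses of @F are direct calls.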
1919 
1920     Module &M = *F->getParent();
1921     Type *Int8Ty = Type::getInt8Ty(M.getContext());
1922 
1923     auto *ID = new GlobalVariable(
1924         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
1925         UndefValue::get(Int8Ty), F->getName() + ".ID");
1926 
1927     for (Use *U : ToBeReplacedStateMachineUses)
1928       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
1929 
1930     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
1931 
1932     Changed = true;
1933   }
1934 
1935   return Changed;
1936 }
1937 
1938 /// Abstract Attribute for tracking ICV values.
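/// For the only ICV currently tracked (nthreads) this enables, e.g., replacing
/// a getter with the value passed to a dominating setter (a hedged sketch):
///
///   call void @omp_set_num_threads(i32 4)
///   %n = call i32 @omp_get_max_threads()   ; can be replaced by i32 4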
1939 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1940   using Base = StateWrapper<BooleanState, AbstractAttribute>;
1941   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1942 
1943   void initialize(Attributor &A) override {
1944     Function *F = getAnchorScope();
1945     if (!F || !A.isFunctionIPOAmendable(*F))
1946       indicatePessimisticFixpoint();
1947   }
1948 
1949   /// Returns true if value is assumed to be tracked.
1950   bool isAssumedTracked() const { return getAssumed(); }
1951 
  /// Returns true if value is known to be tracked.
  bool isKnownTracked() const { return getKnown(); }
1954 
  /// Create an abstract attribute view for the position \p IRP.
1956   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1957 
1958   /// Return the value with which \p I can be replaced for specific \p ICV.
1959   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1960                                                 const Instruction *I,
1961                                                 Attributor &A) const {
1962     return None;
1963   }
1964 
  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return nullptr. If it is not clear yet, return
  /// None.
1968   virtual Optional<Value *>
1969   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
1970 
  // Currently only nthreads is being tracked.
  // This array will only grow over time.
1973   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
1974 
1975   /// See AbstractAttribute::getName()
1976   const std::string getName() const override { return "AAICVTracker"; }
1977 
1978   /// See AbstractAttribute::getIdAddr()
1979   const char *getIdAddr() const override { return &ID; }
1980 
1981   /// This function should return true if the type of the \p AA is AAICVTracker
1982   static bool classof(const AbstractAttribute *AA) {
1983     return (AA->getIdAddr() == &ID);
1984   }
1985 
1986   static const char ID;
1987 };
1988 
1989 struct AAICVTrackerFunction : public AAICVTracker {
1990   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
1991       : AAICVTracker(IRP, A) {}
1992 
1993   // FIXME: come up with better string.
1994   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
1995 
1996   // FIXME: come up with some stats.
1997   void trackStatistics() const override {}
1998 
1999   /// We don't manifest anything for this AA.
2000   ChangeStatus manifest(Attributor &A) override {
2001     return ChangeStatus::UNCHANGED;
2002   }
2003 
  // Map of ICVs to their values at a specific program point.
2005   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2006                   InternalControlVar::ICV___last>
2007       ICVReplacementValuesMap;
2008 
2009   ChangeStatus updateImpl(Attributor &A) override {
2010     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2011 
2012     Function *F = getAnchorScope();
2013 
2014     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2015 
2016     for (InternalControlVar ICV : TrackableICVs) {
2017       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2018 
2019       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2020       auto TrackValues = [&](Use &U, Function &) {
2021         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2022         if (!CI)
2023           return false;
2024 
        // FIXME: Handle setters with more than one argument.
        // Track the new value.
2027         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2028           HasChanged = ChangeStatus::CHANGED;
2029 
2030         return false;
2031       };
2032 
2033       auto CallCheck = [&](Instruction &I) {
2034         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2035         if (ReplVal.hasValue() &&
2036             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2037           HasChanged = ChangeStatus::CHANGED;
2038 
2039         return true;
2040       };
2041 
2042       // Track all changes of an ICV.
2043       SetterRFI.foreachUse(TrackValues, F);
2044 
2045       bool UsedAssumedInformation = false;
2046       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2047                                 UsedAssumedInformation,
2048                                 /* CheckBBLivenessOnly */ true);
2049 
      // TODO: Figure out a way to avoid adding an entry in
      //       ICVReplacementValuesMap.
2052       Instruction *Entry = &F->getEntryBlock().front();
2053       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2054         ValuesMap.insert(std::make_pair(Entry, nullptr));
2055     }
2056 
2057     return HasChanged;
2058   }
2059 
  /// Helper to check if \p I is a call and get the value for it if it is
  /// unique.
2062   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2063                                     InternalControlVar &ICV) const {
2064 
2065     const auto *CB = dyn_cast<CallBase>(I);
2066     if (!CB || CB->hasFnAttr("no_openmp") ||
2067         CB->hasFnAttr("no_openmp_routines"))
2068       return None;
2069 
2070     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2071     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2072     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2073     Function *CalledFunction = CB->getCalledFunction();
2074 
2075     // Indirect call, assume ICV changes.
2076     if (CalledFunction == nullptr)
2077       return nullptr;
2078     if (CalledFunction == GetterRFI.Declaration)
2079       return None;
2080     if (CalledFunction == SetterRFI.Declaration) {
2081       if (ICVReplacementValuesMap[ICV].count(I))
2082         return ICVReplacementValuesMap[ICV].lookup(I);
2083 
2084       return nullptr;
2085     }
2086 
2087     // Since we don't know, assume it changes the ICV.
2088     if (CalledFunction->isDeclaration())
2089       return nullptr;
2090 
2091     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2092         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2093 
2094     if (ICVTrackingAA.isAssumedTracked())
2095       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2096 
2097     // If we don't know, assume it changes.
2098     return nullptr;
2099   }
2100 
  // We don't check for a unique value for a function, so return None.
2102   Optional<Value *>
2103   getUniqueReplacementValue(InternalControlVar ICV) const override {
2104     return None;
2105   }
2106 
2107   /// Return the value with which \p I can be replaced for specific \p ICV.
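  ///
  /// Informally: starting at \p I we walk backwards through the block and then
  /// through predecessor terminators. The first known setter value (or a call
  /// that may change the ICV) on each path determines the result; if different
  /// paths disagree we return nullptr, i.e., "unknown".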
2108   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2109                                         const Instruction *I,
2110                                         Attributor &A) const override {
2111     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2112     if (ValuesMap.count(I))
2113       return ValuesMap.lookup(I);
2114 
2115     SmallVector<const Instruction *, 16> Worklist;
2116     SmallPtrSet<const Instruction *, 16> Visited;
2117     Worklist.push_back(I);
2118 
2119     Optional<Value *> ReplVal;
2120 
2121     while (!Worklist.empty()) {
2122       const Instruction *CurrInst = Worklist.pop_back_val();
2123       if (!Visited.insert(CurrInst).second)
2124         continue;
2125 
2126       const BasicBlock *CurrBB = CurrInst->getParent();
2127 
2128       // Go up and look for all potential setters/calls that might change the
2129       // ICV.
2130       while ((CurrInst = CurrInst->getPrevNode())) {
2131         if (ValuesMap.count(CurrInst)) {
2132           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2133           // Unknown value, track new.
2134           if (!ReplVal.hasValue()) {
2135             ReplVal = NewReplVal;
2136             break;
2137           }
2138 
          // If we found a new value, we can't know the ICV value anymore.
2140           if (NewReplVal.hasValue())
2141             if (ReplVal != NewReplVal)
2142               return nullptr;
2143 
2144           break;
2145         }
2146 
2147         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2148         if (!NewReplVal.hasValue())
2149           continue;
2150 
2151         // Unknown value, track new.
2152         if (!ReplVal.hasValue()) {
2153           ReplVal = NewReplVal;
2154           break;
2155         }
2156 
        // We found a new value, so we can't know the ICV value anymore.
2159         if (ReplVal != NewReplVal)
2160           return nullptr;
2161       }
2162 
2163       // If we are in the same BB and we have a value, we are done.
2164       if (CurrBB == I->getParent() && ReplVal.hasValue())
2165         return ReplVal;
2166 
2167       // Go through all predecessors and add terminators for analysis.
2168       for (const BasicBlock *Pred : predecessors(CurrBB))
2169         if (const Instruction *Terminator = Pred->getTerminator())
2170           Worklist.push_back(Terminator);
2171     }
2172 
2173     return ReplVal;
2174   }
2175 };
2176 
2177 struct AAICVTrackerFunctionReturned : AAICVTracker {
2178   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2179       : AAICVTracker(IRP, A) {}
2180 
2181   // FIXME: come up with better string.
2182   const std::string getAsStr() const override {
2183     return "ICVTrackerFunctionReturned";
2184   }
2185 
2186   // FIXME: come up with some stats.
2187   void trackStatistics() const override {}
2188 
2189   /// We don't manifest anything for this AA.
2190   ChangeStatus manifest(Attributor &A) override {
2191     return ChangeStatus::UNCHANGED;
2192   }
2193 
  // Map of ICVs to their values at a specific program point.
2195   EnumeratedArray<Optional<Value *>, InternalControlVar,
2196                   InternalControlVar::ICV___last>
2197       ICVReplacementValuesMap;
2198 
  /// Return the unique replacement value of \p ICV assumed at the function's
  /// returns, if any.
2201   getUniqueReplacementValue(InternalControlVar ICV) const override {
2202     return ICVReplacementValuesMap[ICV];
2203   }
2204 
2205   ChangeStatus updateImpl(Attributor &A) override {
2206     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2207     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2208         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2209 
2210     if (!ICVTrackingAA.isAssumedTracked())
2211       return indicatePessimisticFixpoint();
2212 
2213     for (InternalControlVar ICV : TrackableICVs) {
2214       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2215       Optional<Value *> UniqueICVValue;
2216 
2217       auto CheckReturnInst = [&](Instruction &I) {
2218         Optional<Value *> NewReplVal =
2219             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2220 
2221         // If we found a second ICV value there is no unique returned value.
2222         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2223           return false;
2224 
2225         UniqueICVValue = NewReplVal;
2226 
2227         return true;
2228       };
2229 
2230       bool UsedAssumedInformation = false;
2231       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2232                                      UsedAssumedInformation,
2233                                      /* CheckBBLivenessOnly */ true))
2234         UniqueICVValue = nullptr;
2235 
2236       if (UniqueICVValue == ReplVal)
2237         continue;
2238 
2239       ReplVal = UniqueICVValue;
2240       Changed = ChangeStatus::CHANGED;
2241     }
2242 
2243     return Changed;
2244   }
2245 };
2246 
2247 struct AAICVTrackerCallSite : AAICVTracker {
2248   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2249       : AAICVTracker(IRP, A) {}
2250 
2251   void initialize(Attributor &A) override {
2252     Function *F = getAnchorScope();
2253     if (!F || !A.isFunctionIPOAmendable(*F))
2254       indicatePessimisticFixpoint();
2255 
2256     // We only initialize this AA for getters, so we need to know which ICV it
2257     // gets.
2258     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2259     for (InternalControlVar ICV : TrackableICVs) {
2260       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2261       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2262       if (Getter.Declaration == getAssociatedFunction()) {
2263         AssociatedICV = ICVInfo.Kind;
2264         return;
2265       }
2266     }
2267 
    // Unknown ICV.
2269     indicatePessimisticFixpoint();
2270   }
2271 
2272   ChangeStatus manifest(Attributor &A) override {
2273     if (!ReplVal.hasValue() || !ReplVal.getValue())
2274       return ChangeStatus::UNCHANGED;
2275 
2276     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2277     A.deleteAfterManifest(*getCtxI());
2278 
2279     return ChangeStatus::CHANGED;
2280   }
2281 
2282   // FIXME: come up with better string.
2283   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2284 
2285   // FIXME: come up with some stats.
2286   void trackStatistics() const override {}
2287 
2288   InternalControlVar AssociatedICV;
2289   Optional<Value *> ReplVal;
2290 
2291   ChangeStatus updateImpl(Attributor &A) override {
2292     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2293         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2294 
2295     // We don't have any information, so we assume it changes the ICV.
2296     if (!ICVTrackingAA.isAssumedTracked())
2297       return indicatePessimisticFixpoint();
2298 
2299     Optional<Value *> NewReplVal =
2300         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2301 
2302     if (ReplVal == NewReplVal)
2303       return ChangeStatus::UNCHANGED;
2304 
2305     ReplVal = NewReplVal;
2306     return ChangeStatus::CHANGED;
2307   }
2308 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2311   Optional<Value *>
2312   getUniqueReplacementValue(InternalControlVar ICV) const override {
2313     return ReplVal;
2314   }
2315 };
2316 
2317 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2318   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2319       : AAICVTracker(IRP, A) {}
2320 
2321   // FIXME: come up with better string.
2322   const std::string getAsStr() const override {
2323     return "ICVTrackerCallSiteReturned";
2324   }
2325 
2326   // FIXME: come up with some stats.
2327   void trackStatistics() const override {}
2328 
2329   /// We don't manifest anything for this AA.
2330   ChangeStatus manifest(Attributor &A) override {
2331     return ChangeStatus::UNCHANGED;
2332   }
2333 
  // Map of ICVs to their values at a specific program point.
2335   EnumeratedArray<Optional<Value *>, InternalControlVar,
2336                   InternalControlVar::ICV___last>
2337       ICVReplacementValuesMap;
2338 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2341   Optional<Value *>
2342   getUniqueReplacementValue(InternalControlVar ICV) const override {
2343     return ICVReplacementValuesMap[ICV];
2344   }
2345 
2346   ChangeStatus updateImpl(Attributor &A) override {
2347     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2348     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2349         *this, IRPosition::returned(*getAssociatedFunction()),
2350         DepClassTy::REQUIRED);
2351 
2352     // We don't have any information, so we assume it changes the ICV.
2353     if (!ICVTrackingAA.isAssumedTracked())
2354       return indicatePessimisticFixpoint();
2355 
2356     for (InternalControlVar ICV : TrackableICVs) {
2357       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2358       Optional<Value *> NewReplVal =
2359           ICVTrackingAA.getUniqueReplacementValue(ICV);
2360 
2361       if (ReplVal == NewReplVal)
2362         continue;
2363 
2364       ReplVal = NewReplVal;
2365       Changed = ChangeStatus::CHANGED;
2366     }
2367     return Changed;
2368   }
2369 };
2370 
2371 struct AAExecutionDomainFunction : public AAExecutionDomain {
2372   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2373       : AAExecutionDomain(IRP, A) {}
2374 
2375   const std::string getAsStr() const override {
2376     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2377            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2378   }
2379 
2380   /// See AbstractAttribute::trackStatistics().
2381   void trackStatistics() const override {}
2382 
2383   void initialize(Attributor &A) override {
2384     Function *F = getAnchorScope();
2385     for (const auto &BB : *F)
2386       SingleThreadedBBs.insert(&BB);
2387     NumBBs = SingleThreadedBBs.size();
2388   }
2389 
2390   ChangeStatus manifest(Attributor &A) override {
2391     LLVM_DEBUG({
2392       for (const BasicBlock *BB : SingleThreadedBBs)
2393         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2394                << BB->getName() << " is executed by a single thread.\n";
2395     });
2396     return ChangeStatus::UNCHANGED;
2397   }
2398 
2399   ChangeStatus updateImpl(Attributor &A) override;
2400 
2401   /// Check if an instruction is executed by a single thread.
2402   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2403     return isExecutedByInitialThreadOnly(*I.getParent());
2404   }
2405 
2406   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2407     return isValidState() && SingleThreadedBBs.contains(&BB);
2408   }
2409 
2410   /// Set of basic blocks that are executed by a single thread.
2411   DenseSet<const BasicBlock *> SingleThreadedBBs;
2412 
2413   /// Total number of basic blocks in this function.
  unsigned long NumBBs;
2415 };
2416 
2417 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2418   Function *F = getAnchorScope();
2419   ReversePostOrderTraversal<Function *> RPOT(F);
2420   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2421 
2422   bool AllCallSitesKnown;
2423   auto PredForCallSite = [&](AbstractCallSite ACS) {
2424     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2425         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2426         DepClassTy::REQUIRED);
2427     return ACS.isDirectCall() &&
2428            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2429                *ACS.getInstruction());
2430   };
2431 
2432   if (!A.checkForAllCallSites(PredForCallSite, *this,
2433                               /* RequiresAllCallSites */ true,
2434                               AllCallSitesKnown))
2435     SingleThreadedBBs.erase(&F->getEntryBlock());
2436 
2437   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2438   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2439 
2440   // Check if the edge into the successor block compares the __kmpc_target_init
2441   // result with -1. If we are in non-SPMD-mode that signals only the main
2442   // thread will execute the edge.
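  //
  // The pattern we are looking for is roughly (illustrative IR):
  //
  //   %tid = call i32 @__kmpc_target_init(..., i1 false /* IsSPMD */, ...)
  //   %cmp = icmp eq i32 %tid, -1
  //   br i1 %cmp, label %initial.thread.only, label %others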
2443   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2444     if (!Edge || !Edge->isConditional())
2445       return false;
2446     if (Edge->getSuccessor(0) != SuccessorBB)
2447       return false;
2448 
2449     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2450     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2451       return false;
2452 
2453     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2454     if (!C)
2455       return false;
2456 
2457     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2458     if (C->isAllOnesValue()) {
2459       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2460       if (!CB || CB->getCalledFunction() != RFI.Declaration)
2461         return false;
2462       const int InitIsSPMDArgNo = 1;
2463       auto *IsSPMDModeCI =
2464           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2465       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2466     }
2467 
2468     return false;
2469   };
2470 
2471   // Merge all the predecessor states into the current basic block. A basic
2472   // block is executed by a single thread if all of its predecessors are.
2473   auto MergePredecessorStates = [&](BasicBlock *BB) {
2474     if (pred_begin(BB) == pred_end(BB))
2475       return SingleThreadedBBs.contains(BB);
2476 
2477     bool IsInitialThread = true;
2478     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2479          PredBB != PredEndBB; ++PredBB) {
2480       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2481                                BB))
2482         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2483     }
2484 
2485     return IsInitialThread;
2486   };
2487 
2488   for (auto *BB : RPOT) {
2489     if (!MergePredecessorStates(BB))
2490       SingleThreadedBBs.erase(BB);
2491   }
2492 
2493   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2494              ? ChangeStatus::UNCHANGED
2495              : ChangeStatus::CHANGED;
2496 }
2497 
/// Try to replace memory allocation calls executed by a single thread with a
/// static buffer of shared memory.
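///
/// As a rough sketch (illustrative IR, details elided):
///
///   %p = call i8* @__kmpc_alloc_shared(i64 4)
///   ...
///   call void @__kmpc_free_shared(i8* %p)
///
/// becomes
///
///   @shared = internal addrspace(3) global [4 x i8] undef
///   ; uses of %p are rewritten to point at @shared, both calls are removed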
2500 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2501   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2502   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2503 
2504   /// Create an abstract attribute view for the position \p IRP.
2505   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2506                                            Attributor &A);
2507 
2508   /// See AbstractAttribute::getName().
2509   const std::string getName() const override { return "AAHeapToShared"; }
2510 
2511   /// See AbstractAttribute::getIdAddr().
2512   const char *getIdAddr() const override { return &ID; }
2513 
2514   /// This function should return true if the type of the \p AA is
2515   /// AAHeapToShared.
2516   static bool classof(const AbstractAttribute *AA) {
2517     return (AA->getIdAddr() == &ID);
2518   }
2519 
2520   /// Unique ID (due to the unique address)
2521   static const char ID;
2522 };
2523 
2524 struct AAHeapToSharedFunction : public AAHeapToShared {
2525   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2526       : AAHeapToShared(IRP, A) {}
2527 
2528   const std::string getAsStr() const override {
2529     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2530            " malloc calls eligible.";
2531   }
2532 
2533   /// See AbstractAttribute::trackStatistics().
2534   void trackStatistics() const override {}
2535 
2536   void initialize(Attributor &A) override {
2537     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2538     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2539 
2540     for (User *U : RFI.Declaration->users())
2541       if (CallBase *CB = dyn_cast<CallBase>(U))
2542         MallocCalls.insert(CB);
2543   }
2544 
2545   ChangeStatus manifest(Attributor &A) override {
2546     if (MallocCalls.empty())
2547       return ChangeStatus::UNCHANGED;
2548 
2549     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2550     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2551 
2552     Function *F = getAnchorScope();
2553     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2554                                             DepClassTy::OPTIONAL);
2555 
2556     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2557     for (CallBase *CB : MallocCalls) {
2558       // Skip replacing this if HeapToStack has already claimed it.
2559       if (HS && HS->isAssumedHeapToStack(*CB))
2560         continue;
2561 
2562       // Find the unique free call to remove it.
2563       SmallVector<CallBase *, 4> FreeCalls;
2564       for (auto *U : CB->users()) {
2565         CallBase *C = dyn_cast<CallBase>(U);
2566         if (C && C->getCalledFunction() == FreeCall.Declaration)
2567           FreeCalls.push_back(C);
2568       }
2569       if (FreeCalls.size() != 1)
2570         continue;
2571 
2572       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2573 
2574       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in "
2575                         << CB->getCaller()->getName() << " with "
2576                         << AllocSize->getZExtValue()
2577                         << " bytes of shared memory\n");
2578 
2579       // Create a new shared memory buffer of the same size as the allocation
2580       // and replace all the uses of the original allocation with it.
2581       Module *M = CB->getModule();
2582       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2583       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2584       auto *SharedMem = new GlobalVariable(
2585           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2586           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2587           GlobalValue::NotThreadLocal,
2588           static_cast<unsigned>(AddressSpace::Shared));
2589       auto *NewBuffer =
2590           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2591 
2592       auto Remark = [&](OptimizationRemark OR) {
2593         return OR << "Replaced globalized variable with "
2594                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2595                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2596                   << "of shared memory.";
2597       };
2598       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2599 
2600       SharedMem->setAlignment(MaybeAlign(32));
2601 
2602       A.changeValueAfterManifest(*CB, *NewBuffer);
2603       A.deleteAfterManifest(*CB);
2604       A.deleteAfterManifest(*FreeCalls.front());
2605 
2606       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2607       Changed = ChangeStatus::CHANGED;
2608     }
2609 
2610     return Changed;
2611   }
2612 
2613   ChangeStatus updateImpl(Attributor &A) override {
2614     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2615     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2616     Function *F = getAnchorScope();
2617 
2618     auto NumMallocCalls = MallocCalls.size();
2619 
    // Only consider malloc calls executed by a single thread with a constant
    // allocation size.
2621     for (User *U : RFI.Declaration->users()) {
2622       const auto &ED = A.getAAFor<AAExecutionDomain>(
2623           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2624       if (CallBase *CB = dyn_cast<CallBase>(U))
        if (!isa<ConstantInt>(CB->getArgOperand(0)) ||
2626             !ED.isExecutedByInitialThreadOnly(*CB))
2627           MallocCalls.erase(CB);
2628     }
2629 
2630     if (NumMallocCalls != MallocCalls.size())
2631       return ChangeStatus::CHANGED;
2632 
2633     return ChangeStatus::UNCHANGED;
2634   }
2635 
2636   /// Collection of all malloc calls in a function.
2637   SmallPtrSet<CallBase *, 4> MallocCalls;
2638 };
2639 
2640 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2641   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2642   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2643 
2644   /// Statistics are tracked as part of manifest for now.
2645   void trackStatistics() const override {}
2646 
2647   /// See AbstractAttribute::getAsStr()
2648   const std::string getAsStr() const override {
2649     if (!isValidState())
2650       return "<invalid>";
2651     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2652                                                             : "generic") +
2653            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2654                                                                : "") +
2655            std::string(" #PRs: ") +
2656            std::to_string(ReachedKnownParallelRegions.size()) +
2657            ", #Unknown PRs: " +
2658            std::to_string(ReachedUnknownParallelRegions.size());
2659   }
2660 
  /// Create an abstract attribute view for the position \p IRP.
2662   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2663 
2664   /// See AbstractAttribute::getName()
2665   const std::string getName() const override { return "AAKernelInfo"; }
2666 
2667   /// See AbstractAttribute::getIdAddr()
2668   const char *getIdAddr() const override { return &ID; }
2669 
2670   /// This function should return true if the type of the \p AA is AAKernelInfo
2671   static bool classof(const AbstractAttribute *AA) {
2672     return (AA->getIdAddr() == &ID);
2673   }
2674 
2675   static const char ID;
2676 };
2677 
2678 /// The function kernel info abstract attribute, basically, what can we say
2679 /// about a function with regards to the KernelInfoState.
2680 struct AAKernelInfoFunction : AAKernelInfo {
2681   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2682       : AAKernelInfo(IRP, A) {}
2683 
2684   /// See AbstractAttribute::initialize(...).
2685   void initialize(Attributor &A) override {
    // This is a high-level transform that might change the constant arguments
    // of the init and deinit calls. We need to tell the Attributor about this
    // to avoid other parts using the current constant value for simplification.
2689     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2690 
2691     Function *Fn = getAnchorScope();
2692     if (!OMPInfoCache.Kernels.count(Fn))
2693       return;
2694 
2695     // Add itself to the reaching kernel and set IsKernelEntry.
2696     ReachingKernelEntries.insert(Fn);
2697     IsKernelEntry = true;
2698 
2699     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2700         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2701     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2702         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2703 
2704     // For kernels we perform more initialization work, first we find the init
2705     // and deinit calls.
2706     auto StoreCallBase = [](Use &U,
2707                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2708                             CallBase *&Storage) {
2709       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2710       assert(CB &&
2711              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2712       assert(!Storage &&
2713              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2714       Storage = CB;
2715       return false;
2716     };
2717     InitRFI.foreachUse(
2718         [&](Use &U, Function &) {
2719           StoreCallBase(U, InitRFI, KernelInitCB);
2720           return false;
2721         },
2722         Fn);
2723     DeinitRFI.foreachUse(
2724         [&](Use &U, Function &) {
2725           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2726           return false;
2727         },
2728         Fn);
2729 
2730     assert((KernelInitCB && KernelDeinitCB) &&
2731            "Kernel without __kmpc_target_init or __kmpc_target_deinit!");
2732 
2733     // For kernels we might need to initialize/finalize the IsSPMD state and
2734     // we need to register a simplification callback so that the Attributor
2735     // knows the constant arguments to __kmpc_target_init and
2736     // __kmpc_target_deinit might actually change.
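    //
    // With the argument positions defined below (IsSPMD = 1, UseStateMachine =
    // 2, RequiresFullRuntime = 3 for init; 1 and 2 for deinit), an init call
    // looks roughly like (illustrative, remaining details elided):
    //
    //   call i32 @__kmpc_target_init(%struct.ident_t* @ident, i1 %IsSPMD,
    //                                i1 %UseStateMachine, i1 %RequiresFullRuntime)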
2737 
2738     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2739         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2740             bool &UsedAssumedInformation) -> Optional<Value *> {
2741       // IRP represents the "use generic state machine" argument of an
2742       // __kmpc_target_init call. We will answer this one with the internal
2743       // state. As long as we are not in an invalid state, we will create a
2744       // custom state machine so the value should be a `i1 false`. If we are
2745       // in an invalid state, we won't change the value that is in the IR.
2746       if (!isValidState())
2747         return nullptr;
2748       if (AA)
2749         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2750       UsedAssumedInformation = !isAtFixpoint();
2751       auto *FalseVal =
2752           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2753       return FalseVal;
2754     };
2755 
2756     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2757         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2758             bool &UsedAssumedInformation) -> Optional<Value *> {
2759       // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
2760       // __kmpc_target_deinit call. We will answer this one with the internal
2761       // state of the SPMDCompatibilityTracker.
2763       if (!SPMDCompatibilityTracker.isValidState())
2764         return nullptr;
2765       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2766         if (AA)
2767           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2768         UsedAssumedInformation = true;
2769       } else {
2770         UsedAssumedInformation = false;
2771       }
2772       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2773                                        SPMDCompatibilityTracker.isAssumed());
2774       return Val;
2775     };
2776 
2777     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2778         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2779             bool &UsedAssumedInformation) -> Optional<Value *> {
2780       // IRP represents the "RequiresFullRuntime" argument of an
2781       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2782       // one with the internal state of the SPMDCompatibilityTracker, so if
2783       // generic then true, if SPMD then false.
2784       if (!SPMDCompatibilityTracker.isValidState())
2785         return nullptr;
2786       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2787         if (AA)
2788           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2789         UsedAssumedInformation = true;
2790       } else {
2791         UsedAssumedInformation = false;
2792       }
2793       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2794                                        !SPMDCompatibilityTracker.isAssumed());
2795       return Val;
2796     };
2797 
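    // The argument indices used below assume the device runtime entry points
    // have roughly these signatures (a sketch derived from how the arguments
    // are used in this file, not an authoritative declaration):
    //   int32_t __kmpc_target_init(ident_t *Ident, bool IsSPMD,
    //                              bool UseGenericStateMachine,
    //                              bool RequiresFullRuntime);
    //   void    __kmpc_target_deinit(ident_t *Ident, bool IsSPMD,
    //                                bool RequiresFullRuntime);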
2798     constexpr const int InitIsSPMDArgNo = 1;
2799     constexpr const int DeinitIsSPMDArgNo = 1;
2800     constexpr const int InitUseStateMachineArgNo = 2;
2801     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2802     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2803     A.registerSimplificationCallback(
2804         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2805         StateMachineSimplifyCB);
2806     A.registerSimplificationCallback(
2807         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2808         IsSPMDModeSimplifyCB);
2809     A.registerSimplificationCallback(
2810         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2811         IsSPMDModeSimplifyCB);
2812     A.registerSimplificationCallback(
2813         IRPosition::callsite_argument(*KernelInitCB,
2814                                       InitRequiresFullRuntimeArgNo),
2815         IsGenericModeSimplifyCB);
2816     A.registerSimplificationCallback(
2817         IRPosition::callsite_argument(*KernelDeinitCB,
2818                                       DeinitRequiresFullRuntimeArgNo),
2819         IsGenericModeSimplifyCB);
2820 
2821     // Check if we know we are in SPMD-mode already.
2822     ConstantInt *IsSPMDArg =
2823         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2824     if (IsSPMDArg && !IsSPMDArg->isZero())
2825       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
2826   }
2827 
2828   /// Modify the IR based on the KernelInfoState now that the fixpoint iteration
2829   /// is finished.
2830   ChangeStatus manifest(Attributor &A) override {
2831     // If we are not looking at a kernel with __kmpc_target_init and
2832     // __kmpc_target_deinit calls we cannot actually manifest the information.
2833     if (!KernelInitCB || !KernelDeinitCB)
2834       return ChangeStatus::UNCHANGED;
2835 
2836     // Known SPMD-mode kernels need no manifest changes.
2837     if (SPMDCompatibilityTracker.isKnown())
2838       return ChangeStatus::UNCHANGED;
2839 
2840     // If we can, we change the execution mode to SPMD-mode; otherwise we build
2841     // a custom state machine.
2842     if (!changeToSPMDMode(A))
2843       buildCustomStateMachine(A);
2844 
2845     return ChangeStatus::CHANGED;
2846   }
2847 
2848   bool changeToSPMDMode(Attributor &A) {
2849     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2850 
2851     if (!SPMDCompatibilityTracker.isAssumed()) {
2852       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
2853         if (!NonCompatibleI)
2854           continue;
2855 
2856         // Skip diagnostics on calls to known OpenMP runtime functions for now.
2857         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
2858           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
2859             continue;
2860 
2861         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2862           ORA << "Value has potential side effects preventing SPMD-mode "
2863                  "execution";
2864           if (isa<CallBase>(NonCompatibleI)) {
2865             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
2866                    "the called function to override";
2867           }
2868           return ORA << ".";
2869         };
2870         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
2871                                                  Remark);
2872 
2873         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
2874                           << *NonCompatibleI << "\n");
2875       }
2876 
2877       return false;
2878     }
2879 
2880     // Adjust the global exec mode flag that tells the runtime what mode this
2881     // kernel is executed in.
2882     Function *Kernel = getAnchorScope();
2883     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
2884         (Kernel->getName() + "_exec_mode").str());
2885     assert(ExecMode && "Kernel without exec mode?");
2886     assert(ExecMode->getInitializer() &&
2887            ExecMode->getInitializer()->isOneValue() &&
2888            "Initially non-SPMD kernel has SPMD exec mode!");
2889     ExecMode->setInitializer(
2890         ConstantInt::get(ExecMode->getInitializer()->getType(), 0));
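    // As an illustration (a sketch; the exact type and linkage come from the
    // frontend): a kernel @foo carries a companion global
    //   @foo_exec_mode = weak constant i8 1   ; 1 == generic mode
    // whose initializer we just flipped to 0 so the device runtime launches the
    // kernel in SPMD mode.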
2891 
2892     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
2893     const int InitIsSPMDArgNo = 1;
2894     const int DeinitIsSPMDArgNo = 1;
2895     const int InitUseStateMachineArgNo = 2;
2896     const int InitRequiresFullRuntimeArgNo = 3;
2897     const int DeinitRequiresFullRuntimeArgNo = 2;
2898 
2899     auto &Ctx = getAnchorValue().getContext();
2900     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
2901                              *ConstantInt::getBool(Ctx, 1));
2902     A.changeUseAfterManifest(
2903         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
2904         *ConstantInt::getBool(Ctx, 0));
2905     A.changeUseAfterManifest(
2906         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
2907         *ConstantInt::getBool(Ctx, 1));
2908     A.changeUseAfterManifest(
2909         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
2910         *ConstantInt::getBool(Ctx, 0));
2911     A.changeUseAfterManifest(
2912         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
2913         *ConstantInt::getBool(Ctx, 0));
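    // After these rewrites the init call is expected to look roughly like
    // (sketch):
    //   call i32 @__kmpc_target_init(%ident, i1 true /* SPMD */,
    //                                i1 false /* no generic state machine */,
    //                                i1 false /* no full runtime */)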
2914 
2915     ++NumOpenMPTargetRegionKernelsSPMD;
2916 
2917     auto Remark = [&](OptimizationRemark OR) {
2918       return OR << "Transformed generic-mode kernel to SPMD-mode.";
2919     };
2920     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
2921     return true;
2922   }
2923 
2924   ChangeStatus buildCustomStateMachine(Attributor &A) {
2925     assert(ReachedKnownParallelRegions.isValidState() &&
2926            "Custom state machine with invalid parallel region states?");
2927 
2928     const int InitIsSPMDArgNo = 1;
2929     const int InitUseStateMachineArgNo = 2;
2930 
2931     // Check if the current configuration is non-SPMD and uses the generic state
2932     // machine. If we already have SPMD mode or a custom state machine we do not
2933     // need to go any further. If either argument is anything but a constant,
2934     // something is weird and we give up.
2935     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
2936         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
2937     ConstantInt *IsSPMD =
2938         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2939 
2940     // If we are stuck with generic mode, try to create a custom device (=GPU)
2941     // state machine which is specialized for the parallel regions that are
2942     // reachable by the kernel.
2943     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
2944         !IsSPMD->isZero())
2945       return ChangeStatus::UNCHANGED;
2946 
2947     // If not SPMD mode, indicate we use a custom state machine now.
2948     auto &Ctx = getAnchorValue().getContext();
2949     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
2950     A.changeUseAfterManifest(
2951         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
2952 
2953     // If we don't actually need a state machine we are done here. This can
2954     // happen if there are no parallel regions at all. In the resulting kernel
2955     // all worker threads will exit right away, leaving the main thread to do
2956     // the work alone.
2957     if (ReachedKnownParallelRegions.empty() &&
2958         ReachedUnknownParallelRegions.empty()) {
2959       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
2960 
2961       auto Remark = [&](OptimizationRemark OR) {
2962         return OR << "Removing unused state machine from generic-mode kernel.";
2963       };
2964       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
2965 
2966       return ChangeStatus::CHANGED;
2967     }
2968 
2969     // Keep track in the statistics of our new shiny custom state machine.
2970     if (ReachedUnknownParallelRegions.empty()) {
2971       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
2972 
2973       auto Remark = [&](OptimizationRemark OR) {
2974         return OR << "Rewriting generic-mode kernel with a customized state "
2975                      "machine.";
2976       };
2977       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
2978     } else {
2979       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
2980 
2981       auto Remark = [&](OptimizationRemarkAnalysis OR) {
2982         return OR << "Generic-mode kernel is executed with a customized state "
2983                      "machine that requires a fallback.";
2984       };
2985       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
2986 
2987       // Tell the user why we ended up with a fallback.
2988       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
2989         if (!UnknownParallelRegionCB)
2990           continue;
2991         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2992           return ORA << "Call may contain unknown parallel regions. Use "
2993                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
2994                         "override.";
2995         };
2996         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
2997                                                  "OMP133", Remark);
2998       }
2999     }
3000 
3001     // Create all the blocks:
3002     //
3003     //                       InitCB = __kmpc_target_init(...)
3004     //                       bool IsWorker = InitCB != -1;
3005     //                       if (IsWorker) {
3006     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3007     //                         void *WorkFn;
3008     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3009     //                         if (!WorkFn) return;
3010     // SMIsActiveCheckBB:       if (Active) {
3011     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3012     //                              ParFn0(...);
3013     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3014     //                              ParFn1(...);
3015     //                            ...
3016     // SMIfCascadeCurrentBB:      else
3017     //                              ((WorkFnTy*)WorkFn)(...);
3018     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3019     //                          }
3020     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3021     //                          goto SMBeginBB;
3022     //                       }
3023     // UserCodeEntryBB:      // user code
3024     //                       __kmpc_target_deinit(...)
3025     //
3026     Function *Kernel = getAssociatedFunction();
3027     assert(Kernel && "Expected an associated function!");
3028 
3029     BasicBlock *InitBB = KernelInitCB->getParent();
3030     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3031         KernelInitCB->getNextNode(), "thread.user_code.check");
3032     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3033         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3034     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3035         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3036     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3037         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3038     BasicBlock *StateMachineIfCascadeCurrentBB =
3039         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3040                            Kernel, UserCodeEntryBB);
3041     BasicBlock *StateMachineEndParallelBB =
3042         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3043                            Kernel, UserCodeEntryBB);
3044     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3045         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3046     A.registerManifestAddedBasicBlock(*InitBB);
3047     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3048     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3049     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3050     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3051     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3052     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3053     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3054 
3055     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3056     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3057 
3058     InitBB->getTerminator()->eraseFromParent();
3059     Instruction *IsWorker =
3060         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3061                          ConstantInt::get(KernelInitCB->getType(), -1),
3062                          "thread.is_worker", InitBB);
3063     IsWorker->setDebugLoc(DLoc);
3064     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3065 
3066     // Create local storage for the work function pointer.
3067     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3068     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3069                                           &Kernel->getEntryBlock().front());
3070     WorkFnAI->setDebugLoc(DLoc);
3071 
3072     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3073     OMPInfoCache.OMPBuilder.updateToLocation(
3074         OpenMPIRBuilder::LocationDescription(
3075             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3076                                      StateMachineBeginBB->end()),
3077             DLoc));
3078 
3079     Value *Ident = KernelInitCB->getArgOperand(0);
3080     Value *GTid = KernelInitCB;
3081 
3082     Module &M = *Kernel->getParent();
3083     FunctionCallee BarrierFn =
3084         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3085             M, OMPRTL___kmpc_barrier_simple_spmd);
3086     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3087         ->setDebugLoc(DLoc);
3088 
3089     FunctionCallee KernelParallelFn =
3090         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3091             M, OMPRTL___kmpc_kernel_parallel);
3092     Instruction *IsActiveWorker = CallInst::Create(
3093         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3094     IsActiveWorker->setDebugLoc(DLoc);
3095     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3096                                        StateMachineBeginBB);
3097     WorkFn->setDebugLoc(DLoc);
3098 
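    // Both the known parallel region wrappers and the opaque work function
    // pointer we get from the runtime are treated as having the signature
    // void(int16_t, int32_t); we build that function type here to cast the
    // pointer and to emit the calls below.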
3099     FunctionType *ParallelRegionFnTy = FunctionType::get(
3100         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3101         false);
3102     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3103         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3104         StateMachineBeginBB);
3105 
3106     Instruction *IsDone =
3107         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3108                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3109                          StateMachineBeginBB);
3110     IsDone->setDebugLoc(DLoc);
3111     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3112                        IsDone, StateMachineBeginBB)
3113         ->setDebugLoc(DLoc);
3114 
3115     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3116                        StateMachineDoneBarrierBB, IsActiveWorker,
3117                        StateMachineIsActiveCheckBB)
3118         ->setDebugLoc(DLoc);
3119 
3120     Value *ZeroArg =
3121         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3122 
3123     // Now that we have most of the CFG skeleton it is time for the if-cascade
3124     // that checks the function pointer we got from the runtime against the
3125     // parallel regions we expect, if there are any.
3126     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3127       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3128       BasicBlock *PRExecuteBB = BasicBlock::Create(
3129           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3130           StateMachineEndParallelBB);
3131       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3132           ->setDebugLoc(DLoc);
3133       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3134           ->setDebugLoc(DLoc);
3135 
3136       BasicBlock *PRNextBB =
3137           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3138                              Kernel, StateMachineEndParallelBB);
3139 
3140       // Check if we need to compare the pointer at all or if we can just
3141       // call the parallel region function.
3142       Value *IsPR;
3143       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3144         Instruction *CmpI = ICmpInst::Create(
3145             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3146             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3147         CmpI->setDebugLoc(DLoc);
3148         IsPR = CmpI;
3149       } else {
3150         IsPR = ConstantInt::getTrue(Ctx);
3151       }
3152 
3153       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3154                          StateMachineIfCascadeCurrentBB)
3155           ->setDebugLoc(DLoc);
3156       StateMachineIfCascadeCurrentBB = PRNextBB;
3157     }
3158 
3159     // At the end of the if-cascade we place the indirect function pointer call
3160     // in case we might need it, that is, if there can be parallel regions we
3161     // have not handled in the if-cascade above.
3162     if (!ReachedUnknownParallelRegions.empty()) {
3163       StateMachineIfCascadeCurrentBB->setName(
3164           "worker_state_machine.parallel_region.fallback.execute");
3165       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3166                        StateMachineIfCascadeCurrentBB)
3167           ->setDebugLoc(DLoc);
3168     }
3169     BranchInst::Create(StateMachineEndParallelBB,
3170                        StateMachineIfCascadeCurrentBB)
3171         ->setDebugLoc(DLoc);
3172 
3173     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3174                          M, OMPRTL___kmpc_kernel_end_parallel),
3175                      {}, "", StateMachineEndParallelBB)
3176         ->setDebugLoc(DLoc);
3177     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3178         ->setDebugLoc(DLoc);
3179 
3180     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3181         ->setDebugLoc(DLoc);
3182     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3183         ->setDebugLoc(DLoc);
3184 
3185     return ChangeStatus::CHANGED;
3186   }
3187 
3188   /// Fixpoint iteration update function. Will be called every time a dependence
3189   /// changes its state (and in the beginning).
3190   ChangeStatus updateImpl(Attributor &A) override {
3191     KernelInfoState StateBefore = getState();
3192 
3193     // Callback to check a read/write instruction.
3194     auto CheckRWInst = [&](Instruction &I) {
3195       // We handle calls later.
3196       if (isa<CallBase>(I))
3197         return true;
3198       // We only care about write effects.
3199       if (!I.mayWriteToMemory())
3200         return true;
3201       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3202         SmallVector<const Value *> Objects;
3203         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3204         if (llvm::all_of(Objects,
3205                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3206           return true;
3207       }
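      // Stores whose underlying objects are all allocas only touch
      // thread-private memory and are therefore fine for SPMD-ization, e.g.
      // (sketch):
      //   %priv = alloca i32
      //   store i32 0, i32* %priv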
3208       // For now we give up on everything but stores.
3209       SPMDCompatibilityTracker.insert(&I);
3210       return true;
3211     };
3212 
3213     bool UsedAssumedInformationInCheckRWInst = false;
3214     if (!SPMDCompatibilityTracker.isAtFixpoint())
3215       if (!A.checkForAllReadWriteInstructions(
3216               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3217         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3218 
3219     if (!IsKernelEntry)
3220       updateReachingKernelEntries(A);
3221 
3222     // Callback to check a call instruction.
3223     bool AllSPMDStatesWereFixed = true;
3224     auto CheckCallInst = [&](Instruction &I) {
3225       auto &CB = cast<CallBase>(I);
3226       auto &CBAA = A.getAAFor<AAKernelInfo>(
3227           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3228       getState() ^= CBAA.getState();
3229       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3230       return true;
3231     };
3232 
3233     bool UsedAssumedInformationInCheckCallInst = false;
3234     if (!A.checkForAllCallLikeInstructions(
3235             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3236       return indicatePessimisticFixpoint();
3237 
3238     // If we haven't used any assumed information for the SPMD state we can fix
3239     // it.
3240     if (!UsedAssumedInformationInCheckRWInst &&
3241         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3242       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3243 
3244     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3245                                      : ChangeStatus::CHANGED;
3246   }
3247 
3248 private:
3249   /// Update info regarding reaching kernels.
3250   void updateReachingKernelEntries(Attributor &A) {
3251     auto PredCallSite = [&](AbstractCallSite ACS) {
3252       Function *Caller = ACS.getInstruction()->getFunction();
3253 
3254       assert(Caller && "Caller is nullptr");
3255 
3256       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3257           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3258       if (CAA.ReachingKernelEntries.isValidState()) {
3259         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3260         return true;
3261       }
3262 
3263       // We lost track of the caller of the associated function; any kernel
3264       // could reach it now.
3265       ReachingKernelEntries.indicatePessimisticFixpoint();
3266 
3267       return true;
3268     };
3269 
3270     bool AllCallSitesKnown;
3271     if (!A.checkForAllCallSites(PredCallSite, *this,
3272                                 true /* RequireAllCallSites */,
3273                                 AllCallSitesKnown))
3274       ReachingKernelEntries.indicatePessimisticFixpoint();
3275   }
3276 };
3277 
3278 /// The call site kernel info abstract attribute, basically, what can we say
3279 /// about a call site with regard to the KernelInfoState. For now this simply
3280 /// forwards the information from the callee.
3281 struct AAKernelInfoCallSite : AAKernelInfo {
3282   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3283       : AAKernelInfo(IRP, A) {}
3284 
3285   /// See AbstractAttribute::initialize(...).
3286   void initialize(Attributor &A) override {
3287     AAKernelInfo::initialize(A);
3288 
3289     CallBase &CB = cast<CallBase>(getAssociatedValue());
3290     Function *Callee = getAssociatedFunction();
3291 
3292     // Helper to lookup an assumption string.
3293     auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
3294       return Fn && hasAssumption(*Fn, AssumptionStr);
3295     };
3296 
3297     // Check for SPMD-mode assumptions.
3298     if (HasAssumption(Callee, "ompx_spmd_amenable"))
3299       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3300 
3301     // First weed out calls we do not care about, that is readonly/readnone
3302     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3303     // parallel region or anything else we are looking for.
3304     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3305       indicateOptimisticFixpoint();
3306       return;
3307     }
3308 
3309     // Next we check if we know the callee. If it is a known OpenMP function
3310     // we will handle it explicitly in the switch below. If it is not, we
3311     // will use an AAKernelInfo object on the callee to gather information and
3312     // merge that into the current state. The latter happens in the updateImpl.
3313     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3314     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3315     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3316       // Unknown callees or declarations are not analyzable; we give up.
3317       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3318 
3319         // Unknown callees might contain parallel regions, except if they have
3320         // an appropriate assumption attached.
3321         if (!(HasAssumption(Callee, "omp_no_openmp") ||
3322               HasAssumption(Callee, "omp_no_parallelism")))
3323           ReachedUnknownParallelRegions.insert(&CB);
3324 
3325         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3326         // idea we can run something unknown in SPMD-mode.
3327         if (!SPMDCompatibilityTracker.isAtFixpoint())
3328           SPMDCompatibilityTracker.insert(&CB);
3329 
3330         // We have updated the state for this unknown call properly; there won't
3331         // be any change, so we indicate a fixpoint.
3332         indicateOptimisticFixpoint();
3333       }
3334       // If the callee is known and can be used in IPO, we will update the state
3335       // based on the callee state in updateImpl.
3336       return;
3337     }
3338 
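    // For __kmpc_parallel_51 we assume the outlined wrapper function is passed
    // as the 7th argument (index 6), i.e., a call shaped roughly like (sketch):
    //   __kmpc_parallel_51(ident, gtid, if_expr, num_threads, proc_bind,
    //                      fn, wrapper_fn, args, nargs)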
3339     const unsigned int WrapperFunctionArgNo = 6;
3340     RuntimeFunction RF = It->getSecond();
3341     switch (RF) {
3342     // All the functions we know are compatible with SPMD mode.
3343     case OMPRTL___kmpc_is_spmd_exec_mode:
3344     case OMPRTL___kmpc_for_static_fini:
3345     case OMPRTL___kmpc_global_thread_num:
3346     case OMPRTL___kmpc_single:
3347     case OMPRTL___kmpc_end_single:
3348     case OMPRTL___kmpc_master:
3349     case OMPRTL___kmpc_end_master:
3350     case OMPRTL___kmpc_barrier:
3351       break;
3352     case OMPRTL___kmpc_for_static_init_4:
3353     case OMPRTL___kmpc_for_static_init_4u:
3354     case OMPRTL___kmpc_for_static_init_8:
3355     case OMPRTL___kmpc_for_static_init_8u: {
3356       // Check the schedule and allow static schedule in SPMD mode.
3357       unsigned ScheduleArgOpNo = 2;
3358       auto *ScheduleTypeCI =
3359           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3360       unsigned ScheduleTypeVal =
3361           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3362       switch (OMPScheduleType(ScheduleTypeVal)) {
3363       case OMPScheduleType::Static:
3364       case OMPScheduleType::StaticChunked:
3365       case OMPScheduleType::Distribute:
3366       case OMPScheduleType::DistributeChunked:
3367         break;
3368       default:
3369         SPMDCompatibilityTracker.insert(&CB);
3370         break;
3371       }
3372     } break;
3373     case OMPRTL___kmpc_target_init:
3374       KernelInitCB = &CB;
3375       break;
3376     case OMPRTL___kmpc_target_deinit:
3377       KernelDeinitCB = &CB;
3378       break;
3379     case OMPRTL___kmpc_parallel_51:
3380       if (auto *ParallelRegion = dyn_cast<Function>(
3381               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3382         ReachedKnownParallelRegions.insert(ParallelRegion);
3383         break;
3384       }
3385       // The condition above should usually get the parallel region function
3386       // pointer and record it. On the off chance it doesn't, we assume the
3387       // worst.
3388       ReachedUnknownParallelRegions.insert(&CB);
3389       break;
3390     case OMPRTL___kmpc_omp_task:
3391       // We do not look into tasks right now, just give up.
3392       SPMDCompatibilityTracker.insert(&CB);
3393       ReachedUnknownParallelRegions.insert(&CB);
3394       break;
3395     default:
3396       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3397       // generally.
3398       SPMDCompatibilityTracker.insert(&CB);
3399       break;
3400     }
3401     // All other OpenMP runtime calls will not reach parallel regions so they
3402     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3403     // have now modeled all effects and there is no need for any update.
3404     indicateOptimisticFixpoint();
3405   }
3406 
3407   ChangeStatus updateImpl(Attributor &A) override {
3408     // TODO: Once we have call site specific value information we can provide
3409     //       call site specific liveness information and then it makes
3410     //       sense to specialize attributes for call sites arguments instead of
3411     //       redirecting requests to the callee argument.
3412     Function *F = getAssociatedFunction();
3413     const IRPosition &FnPos = IRPosition::function(*F);
3414     auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3415     if (getState() == FnAA.getState())
3416       return ChangeStatus::UNCHANGED;
3417     getState() = FnAA.getState();
3418     return ChangeStatus::CHANGED;
3419   }
3420 };
3421 
3422 struct AAFoldRuntimeCall
3423     : public StateWrapper<BooleanState, AbstractAttribute> {
3424   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3425 
3426   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3427 
3428   /// Statistics are tracked as part of manifest for now.
3429   void trackStatistics() const override {}
3430 
3431   /// Create an abstract attribute view for the position \p IRP.
3432   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3433                                               Attributor &A);
3434 
3435   /// See AbstractAttribute::getName()
3436   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3437 
3438   /// See AbstractAttribute::getIdAddr()
3439   const char *getIdAddr() const override { return &ID; }
3440 
3441   /// This function should return true if the type of the \p AA is
3442   /// AAFoldRuntimeCall
3443   static bool classof(const AbstractAttribute *AA) {
3444     return (AA->getIdAddr() == &ID);
3445   }
3446 
3447   static const char ID;
3448 };
3449 
3450 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3451   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3452       : AAFoldRuntimeCall(IRP, A) {}
3453 
3454   /// See AbstractAttribute::getAsStr()
3455   const std::string getAsStr() const override {
3456     if (!isValidState())
3457       return "<invalid>";
3458 
3459     std::string Str("simplified value: ");
3460 
3461     if (!SimplifiedValue.hasValue())
3462       return Str + std::string("none");
3463 
3464     if (!SimplifiedValue.getValue())
3465       return Str + std::string("nullptr");
3466 
3467     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
3468       return Str + std::to_string(CI->getSExtValue());
3469 
3470     return Str + std::string("unknown");
3471   }
3472 
3473   void initialize(Attributor &A) override {
3474     Function *Callee = getAssociatedFunction();
3475 
3476     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3477     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3478     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
3479            "Expected a known OpenMP runtime function");
3480 
3481     RFKind = It->getSecond();
3482 
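    // Register a callback so that whenever the Attributor asks for the value
    // returned by this call site it receives our (possibly still assumed)
    // SimplifiedValue instead of re-deriving it from the IR.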
3483     CallBase &CB = cast<CallBase>(getAssociatedValue());
3484     A.registerSimplificationCallback(
3485         IRPosition::callsite_returned(CB),
3486         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3487             bool &UsedAssumedInformation) -> Optional<Value *> {
3488           assert((isValidState() || (SimplifiedValue.hasValue() &&
3489                                      SimplifiedValue.getValue() == nullptr)) &&
3490                  "Unexpected invalid state!");
3491 
3492           if (!isAtFixpoint()) {
3493             UsedAssumedInformation = true;
3494             if (AA)
3495               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
3496           }
3497           return SimplifiedValue;
3498         });
3499   }
3500 
3501   ChangeStatus updateImpl(Attributor &A) override {
3502     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3503 
3504     switch (RFKind) {
3505     case OMPRTL___kmpc_is_spmd_exec_mode:
3506       Changed |= foldIsSPMDExecMode(A);
3507       break;
3508     default:
3509       llvm_unreachable("Unhandled OpenMP runtime function!");
3510     }
3511 
3512     return Changed;
3513   }
3514 
3515   ChangeStatus manifest(Attributor &A) override {
3516     ChangeStatus Changed = ChangeStatus::UNCHANGED;
3517 
3518     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
3519       Instruction &CB = *getCtxI();
3520       A.changeValueAfterManifest(CB, **SimplifiedValue);
3521       A.deleteAfterManifest(CB);
3522       Changed = ChangeStatus::CHANGED;
3523     }
3524 
3525     return Changed;
3526   }
3527 
3528   ChangeStatus indicatePessimisticFixpoint() override {
3529     SimplifiedValue = nullptr;
3530     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
3531   }
3532 
3533 private:
3534   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
3535   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
3536     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
3537 
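    // Count how many of the kernels that can reach this call site are known or
    // assumed to be in SPMD mode and how many are not. Folding is only sound
    // if every reaching kernel agrees on the execution mode.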
3538     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
3539     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
3540     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
3541         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
3542 
3543     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
3544       return indicatePessimisticFixpoint();
3545 
3546     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
3547       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
3548                                           DepClassTy::REQUIRED);
3549 
3550       if (!AA.isValidState()) {
3551         SimplifiedValue = nullptr;
3552         return indicatePessimisticFixpoint();
3553       }
3554 
3555       if (AA.SPMDCompatibilityTracker.isAssumed()) {
3556         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
3557           ++KnownSPMDCount;
3558         else
3559           ++AssumedSPMDCount;
3560       } else {
3561         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
3562           ++KnownNonSPMDCount;
3563         else
3564           ++AssumedNonSPMDCount;
3565       }
3566     }
3567 
3568     if (KnownSPMDCount && KnownNonSPMDCount)
3569       return indicatePessimisticFixpoint();
3570 
3571     if (AssumedSPMDCount && AssumedNonSPMDCount)
3572       return indicatePessimisticFixpoint();
3573 
3574     auto &Ctx = getAnchorValue().getContext();
3575     if (KnownSPMDCount || AssumedSPMDCount) {
3576       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
3577              "Expected only SPMD kernels!");
3578       // All reaching kernels are in SPMD mode. Update all function calls to
3579       // __kmpc_is_spmd_exec_mode to 1.
3580       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
3581     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
3582       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
3583              "Expected only non-SPMD kernels!");
3584       // All reaching kernels are in non-SPMD mode. Update all function
3585       // calls to __kmpc_is_spmd_exec_mode to 0.
3586       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
3587     } else {
3588       // The set of reaching kernels is empty, so we cannot tell if the
3589       // associated call site can be folded. At this point, SimplifiedValue
3590       // must still be none.
3591       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
3592     }
3593 
3594     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
3595                                                     : ChangeStatus::CHANGED;
3596   }
3597 
3598   /// An optional value the associated value is assumed to fold to. That is, we
3599   /// assume the associated value (which is a call) can be replaced by this
3600   /// simplified value.
3601   Optional<Value *> SimplifiedValue;
3602 
3603   /// The runtime function kind of the callee of the associated call site.
3604   RuntimeFunction RFKind;
3605 };
3606 
3607 } // namespace
3608 
3609 void OpenMPOpt::registerAAs(bool IsModulePass) {
3610   if (SCC.empty())
3612     return;
3613   if (IsModulePass) {
3614     // Ensure we create the AAKernelInfo AAs first and without triggering an
3615     // update. This will make sure we register all value simplification
3616     // callbacks before any other AA has the chance to create an AAValueSimplify
3617     // or similar.
3618     for (Function *Kernel : OMPInfoCache.Kernels)
3619       A.getOrCreateAAFor<AAKernelInfo>(
3620           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
3621           DepClassTy::NONE, /* ForceUpdate */ false,
3622           /* UpdateAfterInit */ false);
3623 
3624     auto &IsSPMDRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_is_spmd_exec_mode];
3625     IsSPMDRFI.foreachUse(SCC, [&](Use &U, Function &) {
3626       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &IsSPMDRFI);
3627       if (!CI)
3628         return false;
3629       A.getOrCreateAAFor<AAFoldRuntimeCall>(
3630           IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
3631           DepClassTy::NONE, /* ForceUpdate */ false,
3632           /* UpdateAfterInit */ false);
3633       return false;
3634     });
3635   }
3636 
3637   // Create CallSite AA for all Getters.
3638   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
3639     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
3640 
3641     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
3642 
3643     auto CreateAA = [&](Use &U, Function &Caller) {
3644       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
3645       if (!CI)
3646         return false;
3647 
3648       auto &CB = cast<CallBase>(*CI);
3649 
3650       IRPosition CBPos = IRPosition::callsite_function(CB);
3651       A.getOrCreateAAFor<AAICVTracker>(CBPos);
3652       return false;
3653     };
3654 
3655     GetterRFI.foreachUse(SCC, CreateAA);
3656   }
3657   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3658   auto CreateAA = [&](Use &U, Function &F) {
3659     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
3660     return false;
3661   };
3662   GlobalizationRFI.foreachUse(SCC, CreateAA);
3663 
3664   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
3665   // every function if this is a device module.
3666   for (auto *F : SCC) {
3667     if (!F->isDeclaration())
3668       A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
3669     if (isOpenMPDevice(M))
3670       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
3671   }
3672 }
3673 
3674 const char AAICVTracker::ID = 0;
3675 const char AAKernelInfo::ID = 0;
3676 const char AAExecutionDomain::ID = 0;
3677 const char AAHeapToShared::ID = 0;
3678 const char AAFoldRuntimeCall::ID = 0;
3679 
3680 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
3681                                               Attributor &A) {
3682   AAICVTracker *AA = nullptr;
3683   switch (IRP.getPositionKind()) {
3684   case IRPosition::IRP_INVALID:
3685   case IRPosition::IRP_FLOAT:
3686   case IRPosition::IRP_ARGUMENT:
3687   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3688     llvm_unreachable("ICVTracker can only be created for function position!");
3689   case IRPosition::IRP_RETURNED:
3690     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
3691     break;
3692   case IRPosition::IRP_CALL_SITE_RETURNED:
3693     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
3694     break;
3695   case IRPosition::IRP_CALL_SITE:
3696     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
3697     break;
3698   case IRPosition::IRP_FUNCTION:
3699     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
3700     break;
3701   }
3702 
3703   return *AA;
3704 }
3705 
3706 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
3707                                                         Attributor &A) {
3708   AAExecutionDomainFunction *AA = nullptr;
3709   switch (IRP.getPositionKind()) {
3710   case IRPosition::IRP_INVALID:
3711   case IRPosition::IRP_FLOAT:
3712   case IRPosition::IRP_ARGUMENT:
3713   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3714   case IRPosition::IRP_RETURNED:
3715   case IRPosition::IRP_CALL_SITE_RETURNED:
3716   case IRPosition::IRP_CALL_SITE:
3717     llvm_unreachable(
3718         "AAExecutionDomain can only be created for function position!");
3719   case IRPosition::IRP_FUNCTION:
3720     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
3721     break;
3722   }
3723 
3724   return *AA;
3725 }
3726 
3727 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
3728                                                   Attributor &A) {
3729   AAHeapToSharedFunction *AA = nullptr;
3730   switch (IRP.getPositionKind()) {
3731   case IRPosition::IRP_INVALID:
3732   case IRPosition::IRP_FLOAT:
3733   case IRPosition::IRP_ARGUMENT:
3734   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3735   case IRPosition::IRP_RETURNED:
3736   case IRPosition::IRP_CALL_SITE_RETURNED:
3737   case IRPosition::IRP_CALL_SITE:
3738     llvm_unreachable(
3739         "AAHeapToShared can only be created for function position!");
3740   case IRPosition::IRP_FUNCTION:
3741     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
3742     break;
3743   }
3744 
3745   return *AA;
3746 }
3747 
3748 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
3749                                               Attributor &A) {
3750   AAKernelInfo *AA = nullptr;
3751   switch (IRP.getPositionKind()) {
3752   case IRPosition::IRP_INVALID:
3753   case IRPosition::IRP_FLOAT:
3754   case IRPosition::IRP_ARGUMENT:
3755   case IRPosition::IRP_RETURNED:
3756   case IRPosition::IRP_CALL_SITE_RETURNED:
3757   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3758     llvm_unreachable("KernelInfo can only be created for function position!");
3759   case IRPosition::IRP_CALL_SITE:
3760     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
3761     break;
3762   case IRPosition::IRP_FUNCTION:
3763     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
3764     break;
3765   }
3766 
3767   return *AA;
3768 }
3769 
3770 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
3771                                                         Attributor &A) {
3772   AAFoldRuntimeCall *AA = nullptr;
3773   switch (IRP.getPositionKind()) {
3774   case IRPosition::IRP_INVALID:
3775   case IRPosition::IRP_FLOAT:
3776   case IRPosition::IRP_ARGUMENT:
3777   case IRPosition::IRP_RETURNED:
3778   case IRPosition::IRP_FUNCTION:
3779   case IRPosition::IRP_CALL_SITE:
3780   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3781     llvm_unreachable("KernelInfo can only be created for call site position!");
3782   case IRPosition::IRP_CALL_SITE_RETURNED:
3783     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
3784     break;
3785   }
3786 
3787   return *AA;
3788 }
3789 
3790 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
3791   if (!containsOpenMP(M))
3792     return PreservedAnalyses::all();
3793   if (DisableOpenMPOptimizations)
3794     return PreservedAnalyses::all();
3795 
3796   FunctionAnalysisManager &FAM =
3797       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
3798   KernelSet Kernels = getDeviceKernels(M);
3799 
3800   auto IsCalled = [&](Function &F) {
3801     if (Kernels.contains(&F))
3802       return true;
3803     for (const User *U : F.users())
3804       if (!isa<BlockAddress>(U))
3805         return true;
3806     return false;
3807   };
3808 
3809   auto EmitRemark = [&](Function &F) {
3810     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
3811     ORE.emit([&]() {
3812       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
3813       return ORA << "Could not internalize function. "
3814                  << "Some optimizations may not be possible.";
3815     });
3816   };
3817 
3818   // Create internal copies of each function if this is a kernel Module. This
3819   // allows interprocedural passes to see every call edge.
3820   DenseSet<const Function *> InternalizedFuncs;
3821   if (isOpenMPDevice(M))
3822     for (Function &F : M)
3823       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F)) {
3824         if (Attributor::internalizeFunction(F, /* Force */ true)) {
3825           InternalizedFuncs.insert(&F);
3826         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
3827           EmitRemark(F);
3828         }
3829       }
3830 
3831   // Look at every function in the Module unless it was internalized.
3832   SmallVector<Function *, 16> SCC;
3833   for (Function &F : M)
3834     if (!F.isDeclaration() && !InternalizedFuncs.contains(&F))
3835       SCC.push_back(&F);
3836 
3837   if (SCC.empty())
3838     return PreservedAnalyses::all();
3839 
3840   AnalysisGetter AG(FAM);
3841 
3842   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
3843     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
3844   };
3845 
3846   BumpPtrAllocator Allocator;
3847   CallGraphUpdater CGUpdater;
3848 
3849   SetVector<Function *> Functions(SCC.begin(), SCC.end());
3850   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
3851 
3852   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
3853   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
3854                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
3855 
3856   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
3857   bool Changed = OMPOpt.run(true);
3858   if (Changed)
3859     return PreservedAnalyses::none();
3860 
3861   return PreservedAnalyses::all();
3862 }
3863 
3864 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
3865                                           CGSCCAnalysisManager &AM,
3866                                           LazyCallGraph &CG,
3867                                           CGSCCUpdateResult &UR) {
3868   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
3869     return PreservedAnalyses::all();
3870   if (DisableOpenMPOptimizations)
3871     return PreservedAnalyses::all();
3872 
3873   SmallVector<Function *, 16> SCC;
3874   // If there are kernels in the module, we have to run on all SCCs.
3875   for (LazyCallGraph::Node &N : C) {
3876     Function *Fn = &N.getFunction();
3877     SCC.push_back(Fn);
3878   }
3879 
3880   if (SCC.empty())
3881     return PreservedAnalyses::all();
3882 
3883   Module &M = *C.begin()->getFunction().getParent();
3884 
3885   KernelSet Kernels = getDeviceKernels(M);
3886 
3887   FunctionAnalysisManager &FAM =
3888       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
3889 
3890   AnalysisGetter AG(FAM);
3891 
3892   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
3893     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
3894   };
3895 
3896   BumpPtrAllocator Allocator;
3897   CallGraphUpdater CGUpdater;
3898   CGUpdater.initialize(CG, C, AM, UR);
3899 
3900   SetVector<Function *> Functions(SCC.begin(), SCC.end());
3901   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
3902                                 /*CGSCC*/ Functions, Kernels);
3903 
3904   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
3905   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
3906                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
3907 
3908   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
3909   bool Changed = OMPOpt.run(false);
3910   if (Changed)
3911     return PreservedAnalyses::none();
3912 
3913   return PreservedAnalyses::all();
3914 }
3915 
3916 namespace {
3917 
3918 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
3919   CallGraphUpdater CGUpdater;
3920   static char ID;
3921 
3922   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
3923     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
3924   }
3925 
3926   void getAnalysisUsage(AnalysisUsage &AU) const override {
3927     CallGraphSCCPass::getAnalysisUsage(AU);
3928   }
3929 
3930   bool runOnSCC(CallGraphSCC &CGSCC) override {
3931     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
3932       return false;
3933     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
3934       return false;
3935 
3936     SmallVector<Function *, 16> SCC;
3937     // If there are kernels in the module, we have to run on all SCCs.
3938     for (CallGraphNode *CGN : CGSCC) {
3939       Function *Fn = CGN->getFunction();
3940       if (!Fn || Fn->isDeclaration())
3941         continue;
3942       SCC.push_back(Fn);
3943     }
3944 
3945     if (SCC.empty())
3946       return false;
3947 
3948     Module &M = CGSCC.getCallGraph().getModule();
3949     KernelSet Kernels = getDeviceKernels(M);
3950 
3951     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
3952     CGUpdater.initialize(CG, CGSCC);
3953 
3954     // Maintain a map of functions to avoid rebuilding the ORE for each one.
3955     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
3956     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
3957       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
3958       if (!ORE)
3959         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
3960       return *ORE;
3961     };
3962 
3963     AnalysisGetter AG;
3964     SetVector<Function *> Functions(SCC.begin(), SCC.end());
3965     BumpPtrAllocator Allocator;
3966     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
3967                                   Allocator,
3968                                   /*CGSCC*/ Functions, Kernels);
3969 
3970     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
3971     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
3972                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
3973 
3974     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
3975     return OMPOpt.run(false);
3976   }
3977 
3978   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
3979 };
3980 
3981 } // end anonymous namespace
3982 
3983 KernelSet llvm::omp::getDeviceKernels(Module &M) {
3984   // TODO: Create a more cross-platform way of determining device kernels.
3985   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
3986   KernelSet Kernels;
3987 
3988   if (!MD)
3989     return Kernels;
3990 
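  // Each kernel is expected to be announced via an entry of roughly this shape
  // (a sketch of the usual NVPTX-style annotation):
  //   !nvvm.annotations = !{!0}
  //   !0 = !{void ()* @kernel_fn, !"kernel", i32 1}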
3991   for (auto *Op : MD->operands()) {
3992     if (Op->getNumOperands() < 2)
3993       continue;
3994     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
3995     if (!KindID || KindID->getString() != "kernel")
3996       continue;
3997 
3998     Function *KernelFn =
3999         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4000     if (!KernelFn)
4001       continue;
4002 
4003     ++NumOpenMPTargetRegionKernels;
4004 
4005     Kernels.insert(KernelFn);
4006   }
4007 
4008   return Kernels;
4009 }
4010 
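// Both checks below rely on module flags the frontend is assumed to set when
// compiling OpenMP (device) code, e.g. (sketch):
//   !llvm.module.flags = !{!0, !1}
//   !0 = !{i32 7, !"openmp", i32 50}
//   !1 = !{i32 7, !"openmp-device", i32 50}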
4011 bool llvm::omp::containsOpenMP(Module &M) {
4012   Metadata *MD = M.getModuleFlag("openmp");
4013   if (!MD)
4014     return false;
4015 
4016   return true;
4017 }
4018 
4019 bool llvm::omp::isOpenMPDevice(Module &M) {
4020   Metadata *MD = M.getModuleFlag("openmp-device");
4021   if (!MD)
4022     return false;
4023 
4024   return true;
4025 }
4026 
4027 char OpenMPOptCGSCCLegacyPass::ID = 0;
4028 
4029 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4030                       "OpenMP specific optimizations", false, false)
4031 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4032 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4033                     "OpenMP specific optimizations", false, false)
4034 
4035 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4036   return new OpenMPOptCGSCCLegacyPass();
4037 }
4038