1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Transforms/IPO/OpenMPOpt.h"
18 
19 #include "llvm/ADT/EnumeratedArray.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/Statistic.h"
22 #include "llvm/Analysis/CallGraph.h"
23 #include "llvm/Analysis/CallGraphSCCPass.h"
24 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
25 #include "llvm/Analysis/ValueTracking.h"
26 #include "llvm/Frontend/OpenMP/OMPConstants.h"
27 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
28 #include "llvm/IR/Assumptions.h"
29 #include "llvm/IR/DiagnosticInfo.h"
30 #include "llvm/IR/GlobalValue.h"
31 #include "llvm/IR/Instruction.h"
32 #include "llvm/IR/IntrinsicInst.h"
33 #include "llvm/InitializePasses.h"
34 #include "llvm/Support/CommandLine.h"
35 #include "llvm/Transforms/IPO.h"
36 #include "llvm/Transforms/IPO/Attributor.h"
37 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
38 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
39 #include "llvm/Transforms/Utils/CodeExtractor.h"
40 
41 using namespace llvm;
42 using namespace omp;
43 
44 #define DEBUG_TYPE "openmp-opt"
45 
46 static cl::opt<bool> DisableOpenMPOptimizations(
47     "openmp-opt-disable", cl::ZeroOrMore,
48     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
49     cl::init(false));
50 
51 static cl::opt<bool> EnableParallelRegionMerging(
52     "openmp-opt-enable-merging", cl::ZeroOrMore,
53     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
54     cl::init(false));
55 
56 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
57                                     cl::Hidden);
58 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
59                                         cl::init(false), cl::Hidden);
60 
61 static cl::opt<bool> HideMemoryTransferLatency(
62     "openmp-hide-memory-transfer-latency",
63     cl::desc("[WIP] Tries to hide the latency of host to device memory"
64              " transfers"),
65     cl::Hidden, cl::init(false));
66 
67 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
68           "Number of OpenMP runtime calls deduplicated");
69 STATISTIC(NumOpenMPParallelRegionsDeleted,
70           "Number of OpenMP parallel regions deleted");
71 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
72           "Number of OpenMP runtime functions identified");
73 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
74           "Number of OpenMP runtime function uses identified");
75 STATISTIC(NumOpenMPTargetRegionKernels,
76           "Number of OpenMP target region entry points (=kernels) identified");
77 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
78           "Number of OpenMP target region entry points (=kernels) executed in "
79           "SPMD-mode instead of generic-mode");
80 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
81           "Number of OpenMP target region entry points (=kernels) executed in "
82           "generic-mode without a state machines");
83 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
84           "Number of OpenMP target region entry points (=kernels) executed in "
85           "generic-mode with customized state machines with fallback");
86 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
87           "Number of OpenMP target region entry points (=kernels) executed in "
88           "generic-mode with customized state machines without fallback");
89 STATISTIC(
90     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
91     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
92 STATISTIC(NumOpenMPParallelRegionsMerged,
93           "Number of OpenMP parallel regions merged");
94 STATISTIC(NumBytesMovedToSharedMemory,
95           "Amount of memory pushed to shared memory");
96 
97 #if !defined(NDEBUG)
98 static constexpr auto TAG = "[" DEBUG_TYPE "]";
99 #endif
100 
101 namespace {
102 
103 enum class AddressSpace : unsigned {
104   Generic = 0,
105   Global = 1,
106   Shared = 3,
107   Constant = 4,
108   Local = 5,
109 };
110 
111 struct AAHeapToShared;
112 
113 struct AAICVTracker;
114 
115 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
116 /// Attributor runs.
117 struct OMPInformationCache : public InformationCache {
118   OMPInformationCache(Module &M, AnalysisGetter &AG,
119                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
120                       SmallPtrSetImpl<Kernel> &Kernels)
121       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
122         Kernels(Kernels) {
123 
124     OMPBuilder.initialize();
125     initializeRuntimeFunctions();
126     initializeInternalControlVars();
127   }
128 
129   /// Generic information that describes an internal control variable.
130   struct InternalControlVarInfo {
131     /// The kind, as described by InternalControlVar enum.
132     InternalControlVar Kind;
133 
134     /// The name of the ICV.
135     StringRef Name;
136 
137     /// Environment variable associated with this ICV.
138     StringRef EnvVarName;
139 
140     /// Initial value kind.
141     ICVInitValue InitKind;
142 
143     /// Initial value.
144     ConstantInt *InitValue;
145 
146     /// Setter RTL function associated with this ICV.
147     RuntimeFunction Setter;
148 
149     /// Getter RTL function associated with this ICV.
150     RuntimeFunction Getter;
151 
152     /// RTL Function corresponding to the override clause of this ICV
153     RuntimeFunction Clause;
154   };
155 
156   /// Generic information that describes a runtime function
157   struct RuntimeFunctionInfo {
158 
159     /// The kind, as described by the RuntimeFunction enum.
160     RuntimeFunction Kind;
161 
162     /// The name of the function.
163     StringRef Name;
164 
165     /// Flag to indicate a variadic function.
166     bool IsVarArg;
167 
168     /// The return type of the function.
169     Type *ReturnType;
170 
171     /// The argument types of the function.
172     SmallVector<Type *, 8> ArgumentTypes;
173 
174     /// The declaration if available.
175     Function *Declaration = nullptr;
176 
177     /// Uses of this runtime function per function containing the use.
178     using UseVector = SmallVector<Use *, 16>;
179 
180     /// Clear UsesMap for runtime function.
181     void clearUsesMap() { UsesMap.clear(); }
182 
183     /// Boolean conversion that is true if the runtime function was found.
184     operator bool() const { return Declaration; }
185 
186     /// Return the vector of uses in function \p F.
187     UseVector &getOrCreateUseVector(Function *F) {
188       std::shared_ptr<UseVector> &UV = UsesMap[F];
189       if (!UV)
190         UV = std::make_shared<UseVector>();
191       return *UV;
192     }
193 
194     /// Return the vector of uses in function \p F or `nullptr` if there are
195     /// none.
196     const UseVector *getUseVector(Function &F) const {
197       auto I = UsesMap.find(&F);
198       if (I != UsesMap.end())
199         return I->second.get();
200       return nullptr;
201     }
202 
203     /// Return how many functions contain uses of this runtime function.
204     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
205 
206     /// Return the number of arguments (or the minimal number for variadic
207     /// functions).
208     size_t getNumArgs() const { return ArgumentTypes.size(); }
209 
210     /// Run the callback \p CB on each use and forget the use if the result is
211     /// true. The callback will be fed the function in which the use was
212     /// encountered as second argument.
213     void foreachUse(SmallVectorImpl<Function *> &SCC,
214                     function_ref<bool(Use &, Function &)> CB) {
215       for (Function *F : SCC)
216         foreachUse(CB, F);
217     }
218 
219     /// Run the callback \p CB on each use within the function \p F and forget
220     /// the use if the result is true.
221     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
222       SmallVector<unsigned, 8> ToBeDeleted;
223       ToBeDeleted.clear();
224 
225       unsigned Idx = 0;
226       UseVector &UV = getOrCreateUseVector(F);
227 
228       for (Use *U : UV) {
229         if (CB(*U, *F))
230           ToBeDeleted.push_back(Idx);
231         ++Idx;
232       }
233 
234       // Remove the to-be-deleted indices in reverse order as prior
235       // modifications will not modify the smaller indices.
236       while (!ToBeDeleted.empty()) {
237         unsigned Idx = ToBeDeleted.pop_back_val();
238         UV[Idx] = UV.back();
239         UV.pop_back();
240       }
241     }
242 
243   private:
244     /// Map from functions to all uses of this runtime function contained in
245     /// them.
246     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
247 
248   public:
249     /// Iterators for the uses of this runtime function.
250     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
251     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
252   };
253 
254   /// An OpenMP-IR-Builder instance
255   OpenMPIRBuilder OMPBuilder;
256 
257   /// Map from runtime function kind to the runtime function description.
258   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
259                   RuntimeFunction::OMPRTL___last>
260       RFIs;
261 
262   /// Map from function declarations/definitions to their runtime enum type.
263   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
264 
265   /// Map from ICV kind to the ICV description.
266   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
267                   InternalControlVar::ICV___last>
268       ICVs;
269 
270   /// Helper to initialize all internal control variable information for those
271   /// defined in OMPKinds.def.
272   void initializeInternalControlVars() {
273 #define ICV_RT_SET(_Name, RTL)                                                 \
274   {                                                                            \
275     auto &ICV = ICVs[_Name];                                                   \
276     ICV.Setter = RTL;                                                          \
277   }
278 #define ICV_RT_GET(Name, RTL)                                                  \
279   {                                                                            \
280     auto &ICV = ICVs[Name];                                                    \
281     ICV.Getter = RTL;                                                          \
282   }
283 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
284   {                                                                            \
285     auto &ICV = ICVs[Enum];                                                    \
286     ICV.Name = _Name;                                                          \
287     ICV.Kind = Enum;                                                           \
288     ICV.InitKind = Init;                                                       \
289     ICV.EnvVarName = _EnvVarName;                                              \
290     switch (ICV.InitKind) {                                                    \
291     case ICV_IMPLEMENTATION_DEFINED:                                           \
292       ICV.InitValue = nullptr;                                                 \
293       break;                                                                   \
294     case ICV_ZERO:                                                             \
295       ICV.InitValue = ConstantInt::get(                                        \
296           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
297       break;                                                                   \
298     case ICV_FALSE:                                                            \
299       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
300       break;                                                                   \
301     case ICV_LAST:                                                             \
302       break;                                                                   \
303     }                                                                          \
304   }
305 #include "llvm/Frontend/OpenMP/OMPKinds.def"
306   }
307 
308   /// Returns true if the function declaration \p F matches the runtime
309   /// function types, that is, return type \p RTFRetType, and argument types
310   /// \p RTFArgTypes.
311   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
312                                   SmallVector<Type *, 8> &RTFArgTypes) {
313     // TODO: We should output information to the user (under debug output
314     //       and via remarks).
315 
316     if (!F)
317       return false;
318     if (F->getReturnType() != RTFRetType)
319       return false;
320     if (F->arg_size() != RTFArgTypes.size())
321       return false;
322 
323     auto RTFTyIt = RTFArgTypes.begin();
324     for (Argument &Arg : F->args()) {
325       if (Arg.getType() != *RTFTyIt)
326         return false;
327 
328       ++RTFTyIt;
329     }
330 
331     return true;
332   }
333 
334   // Helper to collect all uses of the declaration in the UsesMap.
335   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
336     unsigned NumUses = 0;
337     if (!RFI.Declaration)
338       return NumUses;
339     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
340 
341     if (CollectStats) {
342       NumOpenMPRuntimeFunctionsIdentified += 1;
343       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
344     }
345 
346     // TODO: We directly convert uses into proper calls and unknown uses.
347     for (Use &U : RFI.Declaration->uses()) {
348       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
349         if (ModuleSlice.count(UserI->getFunction())) {
350           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
351           ++NumUses;
352         }
353       } else {
354         RFI.getOrCreateUseVector(nullptr).push_back(&U);
355         ++NumUses;
356       }
357     }
358     return NumUses;
359   }
360 
361   // Helper function to recollect uses of a runtime function.
362   void recollectUsesForFunction(RuntimeFunction RTF) {
363     auto &RFI = RFIs[RTF];
364     RFI.clearUsesMap();
365     collectUses(RFI, /*CollectStats*/ false);
366   }
367 
368   // Helper function to recollect uses of all runtime functions.
369   void recollectUses() {
370     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
371       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
372   }
373 
374   /// Helper to initialize all runtime function information for those defined
375   /// in OpenMPKinds.def.
376   void initializeRuntimeFunctions() {
377     Module &M = *((*ModuleSlice.begin())->getParent());
378 
379     // Helper macros for handling __VA_ARGS__ in OMP_RTL
380 #define OMP_TYPE(VarName, ...)                                                 \
381   Type *VarName = OMPBuilder.VarName;                                          \
382   (void)VarName;
383 
384 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
385   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
386   (void)VarName##Ty;                                                           \
387   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
388   (void)VarName##PtrTy;
389 
390 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
391   FunctionType *VarName = OMPBuilder.VarName;                                  \
392   (void)VarName;                                                               \
393   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
394   (void)VarName##Ptr;
395 
396 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
397   StructType *VarName = OMPBuilder.VarName;                                    \
398   (void)VarName;                                                               \
399   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
400   (void)VarName##Ptr;
401 
402 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
403   {                                                                            \
404     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
405     Function *F = M.getFunction(_Name);                                        \
406     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
407       RuntimeFunctionIDMap[F] = _Enum;                                         \
408       auto &RFI = RFIs[_Enum];                                                 \
409       RFI.Kind = _Enum;                                                        \
410       RFI.Name = _Name;                                                        \
411       RFI.IsVarArg = _IsVarArg;                                                \
412       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
413       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
414       RFI.Declaration = F;                                                     \
415       unsigned NumUses = collectUses(RFI);                                     \
416       (void)NumUses;                                                           \
417       LLVM_DEBUG({                                                             \
418         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
419                << " found\n";                                                  \
420         if (RFI.Declaration)                                                   \
421           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
422                  << RFI.getNumFunctionsWithUses()                              \
423                  << " different functions.\n";                                 \
424       });                                                                      \
425     }                                                                          \
426   }
427 #include "llvm/Frontend/OpenMP/OMPKinds.def"
428 
429     // TODO: We should attach the attributes defined in OMPKinds.def.
430   }
431 
432   /// Collection of known kernels (\see Kernel) in the module.
433   SmallPtrSetImpl<Kernel> &Kernels;
434 };
435 
436 template <typename Ty, bool InsertInvalidates = true>
437 struct BooleanStateWithPtrSetVector : public BooleanState {
438 
439   bool contains(Ty *Elem) const { return Set.contains(Elem); }
440   bool insert(Ty *Elem) {
441     if (InsertInvalidates)
442       BooleanState::indicatePessimisticFixpoint();
443     return Set.insert(Elem);
444   }
445 
446   Ty *operator[](int Idx) const { return Set[Idx]; }
447   bool operator==(const BooleanStateWithPtrSetVector &RHS) const {
448     return BooleanState::operator==(RHS) && Set == RHS.Set;
449   }
450   bool operator!=(const BooleanStateWithPtrSetVector &RHS) const {
451     return !(*this == RHS);
452   }
453 
454   bool empty() const { return Set.empty(); }
455   size_t size() const { return Set.size(); }
456 
457   /// "Clamp" this state with \p RHS.
458   BooleanStateWithPtrSetVector &
459   operator^=(const BooleanStateWithPtrSetVector &RHS) {
460     BooleanState::operator^=(RHS);
461     Set.insert(RHS.Set.begin(), RHS.Set.end());
462     return *this;
463   }
464 
465 private:
466   /// A set to keep track of elements.
467   SetVector<Ty *> Set;
468 
469 public:
470   typename decltype(Set)::iterator begin() { return Set.begin(); }
471   typename decltype(Set)::iterator end() { return Set.end(); }
472   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
473   typename decltype(Set)::const_iterator end() const { return Set.end(); }
474 };
475 
476 struct KernelInfoState : AbstractState {
477   /// Flag to track if we reached a fixpoint.
478   bool IsAtFixpoint = false;
479 
480   /// The parallel regions (identified by the outlined parallel functions) that
481   /// can be reached from the associated function.
482   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
483       ReachedKnownParallelRegions;
484 
485   /// State to track what parallel region we might reach.
486   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
487 
488   /// State to track if we are in SPMD-mode, assumed or know, and why we decided
489   /// we cannot be.
490   BooleanStateWithPtrSetVector<Instruction> SPMDCompatibilityTracker;
491 
492   /// The __kmpc_target_init call in this kernel, if any. If we find more than
493   /// one we abort as the kernel is malformed.
494   CallBase *KernelInitCB = nullptr;
495 
496   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
497   /// one we abort as the kernel is malformed.
498   CallBase *KernelDeinitCB = nullptr;
499 
500   /// Abstract State interface
501   ///{
502 
503   KernelInfoState() {}
504   KernelInfoState(bool BestState) {
505     if (!BestState)
506       indicatePessimisticFixpoint();
507   }
508 
509   /// See AbstractState::isValidState(...)
510   bool isValidState() const override { return true; }
511 
512   /// See AbstractState::isAtFixpoint(...)
513   bool isAtFixpoint() const override { return IsAtFixpoint; }
514 
515   /// See AbstractState::indicatePessimisticFixpoint(...)
516   ChangeStatus indicatePessimisticFixpoint() override {
517     IsAtFixpoint = true;
518     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
519     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
520     return ChangeStatus::CHANGED;
521   }
522 
523   /// See AbstractState::indicateOptimisticFixpoint(...)
524   ChangeStatus indicateOptimisticFixpoint() override {
525     IsAtFixpoint = true;
526     return ChangeStatus::UNCHANGED;
527   }
528 
529   /// Return the assumed state
530   KernelInfoState &getAssumed() { return *this; }
531   const KernelInfoState &getAssumed() const { return *this; }
532 
533   bool operator==(const KernelInfoState &RHS) const {
534     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
535       return false;
536     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
537       return false;
538     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
539       return false;
540     return true;
541   }
542 
543   /// Return empty set as the best state of potential values.
544   static KernelInfoState getBestState() { return KernelInfoState(true); }
545 
546   static KernelInfoState getBestState(KernelInfoState &KIS) {
547     return getBestState();
548   }
549 
550   /// Return full set as the worst state of potential values.
551   static KernelInfoState getWorstState() { return KernelInfoState(false); }
552 
553   /// "Clamp" this state with \p KIS.
554   KernelInfoState operator^=(const KernelInfoState &KIS) {
555     // Do not merge two different _init and _deinit call sites.
556     if (KIS.KernelInitCB) {
557       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
558         indicatePessimisticFixpoint();
559       KernelInitCB = KIS.KernelInitCB;
560     }
561     if (KIS.KernelDeinitCB) {
562       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
563         indicatePessimisticFixpoint();
564       KernelDeinitCB = KIS.KernelDeinitCB;
565     }
566     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
567     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
568     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
569     return *this;
570   }
571 
572   KernelInfoState operator&=(const KernelInfoState &KIS) {
573     return (*this ^= KIS);
574   }
575 
576   ///}
577 };
578 
579 /// Used to map the values physically (in the IR) stored in an offload
580 /// array, to a vector in memory.
581 struct OffloadArray {
582   /// Physical array (in the IR).
583   AllocaInst *Array = nullptr;
584   /// Mapped values.
585   SmallVector<Value *, 8> StoredValues;
586   /// Last stores made in the offload array.
587   SmallVector<StoreInst *, 8> LastAccesses;
588 
589   OffloadArray() = default;
590 
591   /// Initializes the OffloadArray with the values stored in \p Array before
592   /// instruction \p Before is reached. Returns false if the initialization
593   /// fails.
594   /// This MUST be used immediately after the construction of the object.
595   bool initialize(AllocaInst &Array, Instruction &Before) {
596     if (!Array.getAllocatedType()->isArrayTy())
597       return false;
598 
599     if (!getValues(Array, Before))
600       return false;
601 
602     this->Array = &Array;
603     return true;
604   }
605 
606   static const unsigned DeviceIDArgNum = 1;
607   static const unsigned BasePtrsArgNum = 3;
608   static const unsigned PtrsArgNum = 4;
609   static const unsigned SizesArgNum = 5;
610 
611 private:
612   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
613   /// \p Array, leaving StoredValues with the values stored before the
614   /// instruction \p Before is reached.
615   bool getValues(AllocaInst &Array, Instruction &Before) {
616     // Initialize container.
617     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
618     StoredValues.assign(NumValues, nullptr);
619     LastAccesses.assign(NumValues, nullptr);
620 
621     // TODO: This assumes the instruction \p Before is in the same
622     //  BasicBlock as Array. Make it general, for any control flow graph.
623     BasicBlock *BB = Array.getParent();
624     if (BB != Before.getParent())
625       return false;
626 
627     const DataLayout &DL = Array.getModule()->getDataLayout();
628     const unsigned int PointerSize = DL.getPointerSize();
629 
630     for (Instruction &I : *BB) {
631       if (&I == &Before)
632         break;
633 
634       if (!isa<StoreInst>(&I))
635         continue;
636 
637       auto *S = cast<StoreInst>(&I);
638       int64_t Offset = -1;
639       auto *Dst =
640           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
641       if (Dst == &Array) {
642         int64_t Idx = Offset / PointerSize;
643         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
644         LastAccesses[Idx] = S;
645       }
646     }
647 
648     return isFilled();
649   }
650 
651   /// Returns true if all values in StoredValues and
652   /// LastAccesses are not nullptrs.
653   bool isFilled() {
654     const unsigned NumValues = StoredValues.size();
655     for (unsigned I = 0; I < NumValues; ++I) {
656       if (!StoredValues[I] || !LastAccesses[I])
657         return false;
658     }
659 
660     return true;
661   }
662 };
663 
664 struct OpenMPOpt {
665 
666   using OptimizationRemarkGetter =
667       function_ref<OptimizationRemarkEmitter &(Function *)>;
668 
669   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
670             OptimizationRemarkGetter OREGetter,
671             OMPInformationCache &OMPInfoCache, Attributor &A)
672       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
673         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
674 
675   /// Check if any remarks are enabled for openmp-opt
676   bool remarksEnabled() {
677     auto &Ctx = M.getContext();
678     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
679   }
680 
681   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
682   bool run(bool IsModulePass) {
683     if (SCC.empty())
684       return false;
685 
686     bool Changed = false;
687 
688     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
689                       << " functions in a slice with "
690                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
691 
692     if (IsModulePass) {
693       Changed |= runAttributor(IsModulePass);
694 
695       // Recollect uses, in case Attributor deleted any.
696       OMPInfoCache.recollectUses();
697 
698       if (remarksEnabled())
699         analysisGlobalization();
700     } else {
701       if (PrintICVValues)
702         printICVs();
703       if (PrintOpenMPKernels)
704         printKernels();
705 
706       Changed |= runAttributor(IsModulePass);
707 
708       // Recollect uses, in case Attributor deleted any.
709       OMPInfoCache.recollectUses();
710 
711       Changed |= deleteParallelRegions();
712       Changed |= rewriteDeviceCodeStateMachine();
713 
714       if (HideMemoryTransferLatency)
715         Changed |= hideMemTransfersLatency();
716       Changed |= deduplicateRuntimeCalls();
717       if (EnableParallelRegionMerging) {
718         if (mergeParallelRegions()) {
719           deduplicateRuntimeCalls();
720           Changed = true;
721         }
722       }
723     }
724 
725     return Changed;
726   }
727 
728   /// Print initial ICV values for testing.
729   /// FIXME: This should be done from the Attributor once it is added.
730   void printICVs() const {
731     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
732                                  ICV_proc_bind};
733 
734     for (Function *F : OMPInfoCache.ModuleSlice) {
735       for (auto ICV : ICVs) {
736         auto ICVInfo = OMPInfoCache.ICVs[ICV];
737         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
738           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
739                      << " Value: "
740                      << (ICVInfo.InitValue
741                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
742                              : "IMPLEMENTATION_DEFINED");
743         };
744 
745         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
746       }
747     }
748   }
749 
750   /// Print OpenMP GPU kernels for testing.
751   void printKernels() const {
752     for (Function *F : SCC) {
753       if (!OMPInfoCache.Kernels.count(F))
754         continue;
755 
756       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
757         return ORA << "OpenMP GPU kernel "
758                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
759       };
760 
761       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
762     }
763   }
764 
765   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
766   /// given it has to be the callee or a nullptr is returned.
767   static CallInst *getCallIfRegularCall(
768       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
769     CallInst *CI = dyn_cast<CallInst>(U.getUser());
770     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
771         (!RFI || CI->getCalledFunction() == RFI->Declaration))
772       return CI;
773     return nullptr;
774   }
775 
776   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
777   /// the callee or a nullptr is returned.
778   static CallInst *getCallIfRegularCall(
779       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
780     CallInst *CI = dyn_cast<CallInst>(&V);
781     if (CI && !CI->hasOperandBundles() &&
782         (!RFI || CI->getCalledFunction() == RFI->Declaration))
783       return CI;
784     return nullptr;
785   }
786 
787 private:
788   /// Merge parallel regions when it is safe.
789   bool mergeParallelRegions() {
790     const unsigned CallbackCalleeOperand = 2;
791     const unsigned CallbackFirstArgOperand = 3;
792     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
793 
794     // Check if there are any __kmpc_fork_call calls to merge.
795     OMPInformationCache::RuntimeFunctionInfo &RFI =
796         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
797 
798     if (!RFI.Declaration)
799       return false;
800 
801     // Unmergable calls that prevent merging a parallel region.
802     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
803         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
804         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
805     };
806 
807     bool Changed = false;
808     LoopInfo *LI = nullptr;
809     DominatorTree *DT = nullptr;
810 
811     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
812 
813     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
814     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
815                          BasicBlock &ContinuationIP) {
816       BasicBlock *CGStartBB = CodeGenIP.getBlock();
817       BasicBlock *CGEndBB =
818           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
819       assert(StartBB != nullptr && "StartBB should not be null");
820       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
821       assert(EndBB != nullptr && "EndBB should not be null");
822       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
823     };
824 
825     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
826                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
827       ReplacementValue = &Inner;
828       return CodeGenIP;
829     };
830 
831     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
832 
833     /// Create a sequential execution region within a merged parallel region,
834     /// encapsulated in a master construct with a barrier for synchronization.
835     auto CreateSequentialRegion = [&](Function *OuterFn,
836                                       BasicBlock *OuterPredBB,
837                                       Instruction *SeqStartI,
838                                       Instruction *SeqEndI) {
839       // Isolate the instructions of the sequential region to a separate
840       // block.
841       BasicBlock *ParentBB = SeqStartI->getParent();
842       BasicBlock *SeqEndBB =
843           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
844       BasicBlock *SeqAfterBB =
845           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
846       BasicBlock *SeqStartBB =
847           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
848 
849       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
850              "Expected a different CFG");
851       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
852       ParentBB->getTerminator()->eraseFromParent();
853 
854       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
855                            BasicBlock &ContinuationIP) {
856         BasicBlock *CGStartBB = CodeGenIP.getBlock();
857         BasicBlock *CGEndBB =
858             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
859         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
860         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
861         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
862         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
863       };
864       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
865 
866       // Find outputs from the sequential region to outside users and
867       // broadcast their values to them.
868       for (Instruction &I : *SeqStartBB) {
869         SmallPtrSet<Instruction *, 4> OutsideUsers;
870         for (User *Usr : I.users()) {
871           Instruction &UsrI = *cast<Instruction>(Usr);
872           // Ignore outputs to LT intrinsics, code extraction for the merged
873           // parallel region will fix them.
874           if (UsrI.isLifetimeStartOrEnd())
875             continue;
876 
877           if (UsrI.getParent() != SeqStartBB)
878             OutsideUsers.insert(&UsrI);
879         }
880 
881         if (OutsideUsers.empty())
882           continue;
883 
884         // Emit an alloca in the outer region to store the broadcasted
885         // value.
886         const DataLayout &DL = M.getDataLayout();
887         AllocaInst *AllocaI = new AllocaInst(
888             I.getType(), DL.getAllocaAddrSpace(), nullptr,
889             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
890 
891         // Emit a store instruction in the sequential BB to update the
892         // value.
893         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
894 
895         // Emit a load instruction and replace the use of the output value
896         // with it.
897         for (Instruction *UsrI : OutsideUsers) {
898           LoadInst *LoadI = new LoadInst(
899               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
900           UsrI->replaceUsesOfWith(&I, LoadI);
901         }
902       }
903 
904       OpenMPIRBuilder::LocationDescription Loc(
905           InsertPointTy(ParentBB, ParentBB->end()), DL);
906       InsertPointTy SeqAfterIP =
907           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
908 
909       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
910 
911       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
912 
913       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
914                         << "\n");
915     };
916 
917     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
918     // contained in BB and only separated by instructions that can be
919     // redundantly executed in parallel. The block BB is split before the first
920     // call (in MergableCIs) and after the last so the entire region we merge
921     // into a single parallel region is contained in a single basic block
922     // without any other instructions. We use the OpenMPIRBuilder to outline
923     // that block and call the resulting function via __kmpc_fork_call.
924     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
925       // TODO: Change the interface to allow single CIs expanded, e.g, to
926       // include an outer loop.
927       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
928 
929       auto Remark = [&](OptimizationRemark OR) {
930         OR << "Parallel region at "
931            << ore::NV("OpenMPParallelMergeFront",
932                       MergableCIs.front()->getDebugLoc())
933            << " merged with parallel regions at ";
934         for (auto *CI : llvm::drop_begin(MergableCIs)) {
935           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
936           if (CI != MergableCIs.back())
937             OR << ", ";
938         }
939         return OR;
940       };
941 
942       emitRemark<OptimizationRemark>(MergableCIs.front(),
943                                      "OpenMPParallelRegionMerging", Remark);
944 
945       Function *OriginalFn = BB->getParent();
946       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
947                         << " parallel regions in " << OriginalFn->getName()
948                         << "\n");
949 
950       // Isolate the calls to merge in a separate block.
951       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
952       BasicBlock *AfterBB =
953           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
954       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
955                            "omp.par.merged");
956 
957       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
958       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
959       BB->getTerminator()->eraseFromParent();
960 
961       // Create sequential regions for sequential instructions that are
962       // in-between mergable parallel regions.
963       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
964            It != End; ++It) {
965         Instruction *ForkCI = *It;
966         Instruction *NextForkCI = *(It + 1);
967 
968         // Continue if there are not in-between instructions.
969         if (ForkCI->getNextNode() == NextForkCI)
970           continue;
971 
972         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
973                                NextForkCI->getPrevNode());
974       }
975 
976       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
977                                                DL);
978       IRBuilder<>::InsertPoint AllocaIP(
979           &OriginalFn->getEntryBlock(),
980           OriginalFn->getEntryBlock().getFirstInsertionPt());
981       // Create the merged parallel region with default proc binding, to
982       // avoid overriding binding settings, and without explicit cancellation.
983       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
984           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
985           OMP_PROC_BIND_default, /* IsCancellable */ false);
986       BranchInst::Create(AfterBB, AfterIP.getBlock());
987 
988       // Perform the actual outlining.
989       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
990                                        /* AllowExtractorSinking */ true);
991 
992       Function *OutlinedFn = MergableCIs.front()->getCaller();
993 
994       // Replace the __kmpc_fork_call calls with direct calls to the outlined
995       // callbacks.
996       SmallVector<Value *, 8> Args;
997       for (auto *CI : MergableCIs) {
998         Value *Callee =
999             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1000         FunctionType *FT =
1001             cast<FunctionType>(Callee->getType()->getPointerElementType());
1002         Args.clear();
1003         Args.push_back(OutlinedFn->getArg(0));
1004         Args.push_back(OutlinedFn->getArg(1));
1005         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1006              U < E; ++U)
1007           Args.push_back(CI->getArgOperand(U));
1008 
1009         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1010         if (CI->getDebugLoc())
1011           NewCI->setDebugLoc(CI->getDebugLoc());
1012 
1013         // Forward parameter attributes from the callback to the callee.
1014         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1015              U < E; ++U)
1016           for (const Attribute &A : CI->getAttributes().getParamAttributes(U))
1017             NewCI->addParamAttr(
1018                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1019 
1020         // Emit an explicit barrier to replace the implicit fork-join barrier.
1021         if (CI != MergableCIs.back()) {
1022           // TODO: Remove barrier if the merged parallel region includes the
1023           // 'nowait' clause.
1024           OMPInfoCache.OMPBuilder.createBarrier(
1025               InsertPointTy(NewCI->getParent(),
1026                             NewCI->getNextNode()->getIterator()),
1027               OMPD_parallel);
1028         }
1029 
1030         auto Remark = [&](OptimizationRemark OR) {
1031           return OR << "Parallel region at "
1032                     << ore::NV("OpenMPParallelMerge", CI->getDebugLoc())
1033                     << " merged with "
1034                     << ore::NV("OpenMPParallelMergeFront",
1035                                MergableCIs.front()->getDebugLoc());
1036         };
1037         if (CI != MergableCIs.front())
1038           emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionMerging",
1039                                          Remark);
1040 
1041         CI->eraseFromParent();
1042       }
1043 
1044       assert(OutlinedFn != OriginalFn && "Outlining failed");
1045       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1046       CGUpdater.reanalyzeFunction(*OriginalFn);
1047 
1048       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1049 
1050       return true;
1051     };
1052 
1053     // Helper function that identifes sequences of
1054     // __kmpc_fork_call uses in a basic block.
1055     auto DetectPRsCB = [&](Use &U, Function &F) {
1056       CallInst *CI = getCallIfRegularCall(U, &RFI);
1057       BB2PRMap[CI->getParent()].insert(CI);
1058 
1059       return false;
1060     };
1061 
1062     BB2PRMap.clear();
1063     RFI.foreachUse(SCC, DetectPRsCB);
1064     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1065     // Find mergable parallel regions within a basic block that are
1066     // safe to merge, that is any in-between instructions can safely
1067     // execute in parallel after merging.
1068     // TODO: support merging across basic-blocks.
1069     for (auto &It : BB2PRMap) {
1070       auto &CIs = It.getSecond();
1071       if (CIs.size() < 2)
1072         continue;
1073 
1074       BasicBlock *BB = It.getFirst();
1075       SmallVector<CallInst *, 4> MergableCIs;
1076 
1077       /// Returns true if the instruction is mergable, false otherwise.
1078       /// A terminator instruction is unmergable by definition since merging
1079       /// works within a BB. Instructions before the mergable region are
1080       /// mergable if they are not calls to OpenMP runtime functions that may
1081       /// set different execution parameters for subsequent parallel regions.
1082       /// Instructions in-between parallel regions are mergable if they are not
1083       /// calls to any non-intrinsic function since that may call a non-mergable
1084       /// OpenMP runtime function.
1085       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1086         // We do not merge across BBs, hence return false (unmergable) if the
1087         // instruction is a terminator.
1088         if (I.isTerminator())
1089           return false;
1090 
1091         if (!isa<CallInst>(&I))
1092           return true;
1093 
1094         CallInst *CI = cast<CallInst>(&I);
1095         if (IsBeforeMergableRegion) {
1096           Function *CalledFunction = CI->getCalledFunction();
1097           if (!CalledFunction)
1098             return false;
1099           // Return false (unmergable) if the call before the parallel
1100           // region calls an explicit affinity (proc_bind) or number of
1101           // threads (num_threads) compiler-generated function. Those settings
1102           // may be incompatible with following parallel regions.
1103           // TODO: ICV tracking to detect compatibility.
1104           for (const auto &RFI : UnmergableCallsInfo) {
1105             if (CalledFunction == RFI.Declaration)
1106               return false;
1107           }
1108         } else {
1109           // Return false (unmergable) if there is a call instruction
1110           // in-between parallel regions when it is not an intrinsic. It
1111           // may call an unmergable OpenMP runtime function in its callpath.
1112           // TODO: Keep track of possible OpenMP calls in the callpath.
1113           if (!isa<IntrinsicInst>(CI))
1114             return false;
1115         }
1116 
1117         return true;
1118       };
1119       // Find maximal number of parallel region CIs that are safe to merge.
1120       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1121         Instruction &I = *It;
1122         ++It;
1123 
1124         if (CIs.count(&I)) {
1125           MergableCIs.push_back(cast<CallInst>(&I));
1126           continue;
1127         }
1128 
1129         // Continue expanding if the instruction is mergable.
1130         if (IsMergable(I, MergableCIs.empty()))
1131           continue;
1132 
1133         // Forward the instruction iterator to skip the next parallel region
1134         // since there is an unmergable instruction which can affect it.
1135         for (; It != End; ++It) {
1136           Instruction &SkipI = *It;
1137           if (CIs.count(&SkipI)) {
1138             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1139                               << " due to " << I << "\n");
1140             ++It;
1141             break;
1142           }
1143         }
1144 
1145         // Store mergable regions found.
1146         if (MergableCIs.size() > 1) {
1147           MergableCIsVector.push_back(MergableCIs);
1148           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1149                             << " parallel regions in block " << BB->getName()
1150                             << " of function " << BB->getParent()->getName()
1151                             << "\n";);
1152         }
1153 
1154         MergableCIs.clear();
1155       }
1156 
1157       if (!MergableCIsVector.empty()) {
1158         Changed = true;
1159 
1160         for (auto &MergableCIs : MergableCIsVector)
1161           Merge(MergableCIs, BB);
1162         MergableCIsVector.clear();
1163       }
1164     }
1165 
1166     if (Changed) {
1167       /// Re-collect use for fork calls, emitted barrier calls, and
1168       /// any emitted master/end_master calls.
1169       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1170       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1171       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1172       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1173     }
1174 
1175     return Changed;
1176   }
1177 
1178   /// Try to delete parallel regions if possible.
1179   bool deleteParallelRegions() {
1180     const unsigned CallbackCalleeOperand = 2;
1181 
1182     OMPInformationCache::RuntimeFunctionInfo &RFI =
1183         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1184 
1185     if (!RFI.Declaration)
1186       return false;
1187 
1188     bool Changed = false;
1189     auto DeleteCallCB = [&](Use &U, Function &) {
1190       CallInst *CI = getCallIfRegularCall(U);
1191       if (!CI)
1192         return false;
1193       auto *Fn = dyn_cast<Function>(
1194           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1195       if (!Fn)
1196         return false;
1197       if (!Fn->onlyReadsMemory())
1198         return false;
1199       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1200         return false;
1201 
1202       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1203                         << CI->getCaller()->getName() << "\n");
1204 
1205       auto Remark = [&](OptimizationRemark OR) {
1206         return OR << "Parallel region in "
1207                   << ore::NV("OpenMPParallelDelete", CI->getCaller()->getName())
1208                   << " deleted";
1209       };
1210       emitRemark<OptimizationRemark>(CI, "OpenMPParallelRegionDeletion",
1211                                      Remark);
1212 
1213       CGUpdater.removeCallSite(*CI);
1214       CI->eraseFromParent();
1215       Changed = true;
1216       ++NumOpenMPParallelRegionsDeleted;
1217       return true;
1218     };
1219 
1220     RFI.foreachUse(SCC, DeleteCallCB);
1221 
1222     return Changed;
1223   }
1224 
1225   /// Try to eliminate runtime calls by reusing existing ones.
1226   bool deduplicateRuntimeCalls() {
1227     bool Changed = false;
1228 
1229     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1230         OMPRTL_omp_get_num_threads,
1231         OMPRTL_omp_in_parallel,
1232         OMPRTL_omp_get_cancellation,
1233         OMPRTL_omp_get_thread_limit,
1234         OMPRTL_omp_get_supported_active_levels,
1235         OMPRTL_omp_get_level,
1236         OMPRTL_omp_get_ancestor_thread_num,
1237         OMPRTL_omp_get_team_size,
1238         OMPRTL_omp_get_active_level,
1239         OMPRTL_omp_in_final,
1240         OMPRTL_omp_get_proc_bind,
1241         OMPRTL_omp_get_num_places,
1242         OMPRTL_omp_get_num_procs,
1243         OMPRTL_omp_get_place_num,
1244         OMPRTL_omp_get_partition_num_places,
1245         OMPRTL_omp_get_partition_place_nums};
1246 
1247     // Global-tid is handled separately.
1248     SmallSetVector<Value *, 16> GTIdArgs;
1249     collectGlobalThreadIdArguments(GTIdArgs);
1250     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1251                       << " global thread ID arguments\n");
1252 
1253     for (Function *F : SCC) {
1254       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1255         Changed |= deduplicateRuntimeCalls(
1256             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1257 
1258       // __kmpc_global_thread_num is special as we can replace it with an
1259       // argument in enough cases to make it worth trying.
1260       Value *GTIdArg = nullptr;
1261       for (Argument &Arg : F->args())
1262         if (GTIdArgs.count(&Arg)) {
1263           GTIdArg = &Arg;
1264           break;
1265         }
1266       Changed |= deduplicateRuntimeCalls(
1267           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1268     }
1269 
1270     return Changed;
1271   }
1272 
1273   /// Tries to hide the latency of runtime calls that involve host to
1274   /// device memory transfers by splitting them into their "issue" and "wait"
1275   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1276   /// moved downards as much as possible. The "issue" issues the memory transfer
1277   /// asynchronously, returning a handle. The "wait" waits in the returned
1278   /// handle for the memory transfer to finish.
1279   bool hideMemTransfersLatency() {
1280     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1281     bool Changed = false;
1282     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1283       auto *RTCall = getCallIfRegularCall(U, &RFI);
1284       if (!RTCall)
1285         return false;
1286 
1287       OffloadArray OffloadArrays[3];
1288       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1289         return false;
1290 
1291       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1292 
1293       // TODO: Check if can be moved upwards.
1294       bool WasSplit = false;
1295       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1296       if (WaitMovementPoint)
1297         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1298 
1299       Changed |= WasSplit;
1300       return WasSplit;
1301     };
1302     RFI.foreachUse(SCC, SplitMemTransfers);
1303 
1304     return Changed;
1305   }
1306 
1307   void analysisGlobalization() {
1308     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1309 
1310     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1311       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1312         auto Remark = [&](OptimizationRemarkMissed ORM) {
1313           return ORM
1314                  << "Found thread data sharing on the GPU. "
1315                  << "Expect degraded performance due to data globalization.";
1316         };
1317         emitRemark<OptimizationRemarkMissed>(CI, "OpenMPGlobalization", Remark);
1318       }
1319 
1320       return false;
1321     };
1322 
1323     RFI.foreachUse(SCC, CheckGlobalization);
1324   }
1325 
1326   /// Maps the values stored in the offload arrays passed as arguments to
1327   /// \p RuntimeCall into the offload arrays in \p OAs.
1328   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1329                                 MutableArrayRef<OffloadArray> OAs) {
1330     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1331 
1332     // A runtime call that involves memory offloading looks something like:
1333     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1334     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1335     // ...)
1336     // So, the idea is to access the allocas that allocate space for these
1337     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1338     // Therefore:
1339     // i8** %offload_baseptrs.
1340     Value *BasePtrsArg =
1341         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1342     // i8** %offload_ptrs.
1343     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1344     // i8** %offload_sizes.
1345     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1346 
1347     // Get values stored in **offload_baseptrs.
1348     auto *V = getUnderlyingObject(BasePtrsArg);
1349     if (!isa<AllocaInst>(V))
1350       return false;
1351     auto *BasePtrsArray = cast<AllocaInst>(V);
1352     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1353       return false;
1354 
    // Get values stored in **offload_ptrs.
1356     V = getUnderlyingObject(PtrsArg);
1357     if (!isa<AllocaInst>(V))
1358       return false;
1359     auto *PtrsArray = cast<AllocaInst>(V);
1360     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1361       return false;
1362 
1363     // Get values stored in **offload_sizes.
1364     V = getUnderlyingObject(SizesArg);
    // If it's a [constant] global array there is nothing to analyze; succeed
    // only if it is actually constant.
1366     if (isa<GlobalValue>(V))
1367       return isa<Constant>(V);
1368     if (!isa<AllocaInst>(V))
1369       return false;
1370 
1371     auto *SizesArray = cast<AllocaInst>(V);
1372     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1373       return false;
1374 
1375     return true;
1376   }
1377 
1378   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1379   /// For now this is a way to test that the function getValuesInOffloadArrays
1380   /// is working properly.
1381   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1382   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1383     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1384 
1385     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1386     std::string ValuesStr;
1387     raw_string_ostream Printer(ValuesStr);
1388     std::string Separator = " --- ";
1389 
1390     for (auto *BP : OAs[0].StoredValues) {
1391       BP->print(Printer);
1392       Printer << Separator;
1393     }
1394     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1395     ValuesStr.clear();
1396 
1397     for (auto *P : OAs[1].StoredValues) {
1398       P->print(Printer);
1399       Printer << Separator;
1400     }
1401     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1402     ValuesStr.clear();
1403 
1404     for (auto *S : OAs[2].StoredValues) {
1405       S->print(Printer);
1406       Printer << Separator;
1407     }
1408     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1409   }
1410 
  /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
1413   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1414     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1415     //  Make it traverse the CFG.
1416 
1417     Instruction *CurrentI = &RuntimeCall;
1418     bool IsWorthIt = false;
1419     while ((CurrentI = CurrentI->getNextNode())) {
1420 
1421       // TODO: Once we detect the regions to be offloaded we should use the
1422       //  alias analysis manager to check if CurrentI may modify one of
1423       //  the offloaded regions.
1424       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1425         if (IsWorthIt)
1426           return CurrentI;
1427 
1428         return nullptr;
1429       }
1430 
      // FIXME: For now, moving the call over anything without side effects is
      //  considered worth it.
1433       IsWorthIt = true;
1434     }
1435 
    // We reached the end of the block; return its terminator.
1437     return RuntimeCall.getParent()->getTerminator();
1438   }
1439 
1440   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1441   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1442                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It is used to store information about the async transfer,
    // allowing us to wait on it later.
1446     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1447     auto *F = RuntimeCall.getCaller();
1448     Instruction *FirstInst = &(F->getEntryBlock().front());
1449     AllocaInst *Handle = new AllocaInst(
1450         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1451 
1452     // Add "issue" runtime call declaration:
1453     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1454     //   i8**, i8**, i64*, i64*)
1455     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1456         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1457 
1458     // Change RuntimeCall call site for its asynchronous version.
1459     SmallVector<Value *, 16> Args;
1460     for (auto &Arg : RuntimeCall.args())
1461       Args.push_back(Arg.get());
1462     Args.push_back(Handle);
1463 
1464     CallInst *IssueCallsite =
1465         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1466     RuntimeCall.eraseFromParent();
1467 
1468     // Add "wait" runtime call declaration:
1469     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1470     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1471         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1472 
1473     Value *WaitParams[2] = {
1474         IssueCallsite->getArgOperand(
1475             OffloadArray::DeviceIDArgNum), // device_id.
1476         Handle                             // handle to wait on.
1477     };
1478     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1479 
1480     return true;
1481   }
1482 
1483   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1484                                     bool GlobalOnly, bool &SingleChoice) {
1485     if (CurrentIdent == NextIdent)
1486       return CurrentIdent;
1487 
1488     // TODO: Figure out how to actually combine multiple debug locations. For
1489     //       now we just keep an existing one if there is a single choice.
1490     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1491       SingleChoice = !CurrentIdent;
1492       return NextIdent;
1493     }
1494     return nullptr;
1495   }
1496 
  /// Return a `struct ident_t*` value that represents the ones used in the
1498   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1499   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1500   /// return value we create one from scratch. We also do not yet combine
1501   /// information, e.g., the source locations, see combinedIdentStruct.
1502   Value *
1503   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1504                                  Function &F, bool GlobalOnly) {
1505     bool SingleChoice = true;
1506     Value *Ident = nullptr;
1507     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1508       CallInst *CI = getCallIfRegularCall(U, &RFI);
1509       if (!CI || &F != &Caller)
1510         return false;
1511       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1512                                   /* GlobalOnly */ true, SingleChoice);
1513       return false;
1514     };
1515     RFI.foreachUse(SCC, CombineIdentStruct);
1516 
1517     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate but we work around it for now.
1520       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1521         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1522             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1524       // TODO: Use the debug locations of the calls instead.
1525       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1526       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1527     }
1528     return Ident;
1529   }
1530 
1531   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1532   /// \p ReplVal if given.
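  ///
  /// A minimal sketch of the effect for __kmpc_global_thread_num (IR names
  /// illustrative):
  ///
  ///   %tid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   ...
  ///   %tid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///
  /// is rewritten such that a single call, moved to the beginning of the
  /// function, is kept and all uses of %tid1 are replaced by %tid0.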
1533   bool deduplicateRuntimeCalls(Function &F,
1534                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1535                                Value *ReplVal = nullptr) {
1536     auto *UV = RFI.getUseVector(F);
1537     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1538       return false;
1539 
    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n"));
1543 
1544     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1545                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1546            "Unexpected replacement value!");
1547 
1548     // TODO: Use dominance to find a good position instead.
1549     auto CanBeMoved = [this](CallBase &CB) {
1550       unsigned NumArgs = CB.getNumArgOperands();
1551       if (NumArgs == 0)
1552         return true;
1553       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1554         return false;
1555       for (unsigned u = 1; u < NumArgs; ++u)
1556         if (isa<Instruction>(CB.getArgOperand(u)))
1557           return false;
1558       return true;
1559     };
1560 
1561     if (!ReplVal) {
1562       for (Use *U : *UV)
1563         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1564           if (!CanBeMoved(*CI))
1565             continue;
1566 
1567           auto Remark = [&](OptimizationRemark OR) {
1568             return OR << "OpenMP runtime call "
1569                       << ore::NV("OpenMPOptRuntime", RFI.Name)
1570                       << " moved to beginning of OpenMP region";
1571           };
1572           emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);
1573 
1574           CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1575           ReplVal = CI;
1576           break;
1577         }
1578       if (!ReplVal)
1579         return false;
1580     }
1581 
1582     // If we use a call as a replacement value we need to make sure the ident is
1583     // valid at the new location. For now we just pick a global one, either
1584     // existing and used by one of the calls, or created from scratch.
1585     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1586       if (CI->getNumArgOperands() > 0 &&
1587           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1588         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1589                                                       /* GlobalOnly */ true);
1590         CI->setArgOperand(0, Ident);
1591       }
1592     }
1593 
1594     bool Changed = false;
1595     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1596       CallInst *CI = getCallIfRegularCall(U, &RFI);
1597       if (!CI || CI == ReplVal || &F != &Caller)
1598         return false;
1599       assert(CI->getCaller() == &F && "Unexpected call!");
1600 
1601       auto Remark = [&](OptimizationRemark OR) {
1602         return OR << "OpenMP runtime call "
1603                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated";
1604       };
1605       emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeDeduplicated", Remark);
1606 
1607       CGUpdater.removeCallSite(*CI);
1608       CI->replaceAllUsesWith(ReplVal);
1609       CI->eraseFromParent();
1610       ++NumOpenMPRuntimeCallsDeduplicated;
1611       Changed = true;
1612       return true;
1613     };
1614     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1615 
1616     return Changed;
1617   }
1618 
1619   /// Collect arguments that represent the global thread id in \p GTIdArgs.
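  ///
  /// For example (a sketch; the helper name is illustrative): if every call
  /// site of an internal function passes a known global thread id, or the
  /// result of __kmpc_global_thread_num, in the same argument position, that
  /// argument is itself a global thread id:
  ///
  ///   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @internal_helper(i32 %gtid)  ; argument 0 of @internal_helper
  ///                                          ; is collected as a GTId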
1620   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1621     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1622     //       initialization. We could define an AbstractAttribute instead and
1623     //       run the Attributor here once it can be run as an SCC pass.
1624 
1625     // Helper to check the argument \p ArgNo at all call sites of \p F for
1626     // a GTId.
1627     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1628       if (!F.hasLocalLinkage())
1629         return false;
1630       for (Use &U : F.uses()) {
1631         if (CallInst *CI = getCallIfRegularCall(U)) {
1632           Value *ArgOp = CI->getArgOperand(ArgNo);
1633           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1634               getCallIfRegularCall(
1635                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1636             continue;
1637         }
1638         return false;
1639       }
1640       return true;
1641     };
1642 
1643     // Helper to identify uses of a GTId as GTId arguments.
1644     auto AddUserArgs = [&](Value &GTId) {
1645       for (Use &U : GTId.uses())
1646         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1647           if (CI->isArgOperand(&U))
1648             if (Function *Callee = CI->getCalledFunction())
1649               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1650                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1651     };
1652 
1653     // The argument users of __kmpc_global_thread_num calls are GTIds.
1654     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1655         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1656 
1657     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1658       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1659         AddUserArgs(*CI);
1660       return false;
1661     });
1662 
    // Transitively search for more arguments by looking at the users of the
    // ones we know already. During the search the GTIdArgs vector is extended,
    // so we cannot cache the size nor use a range-based for loop.
1666     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1667       AddUserArgs(*GTIdArgs[u]);
1668   }
1669 
1670   /// Kernel (=GPU) optimizations and utility functions
1671   ///
1672   ///{{
1673 
1674   /// Check if \p F is a kernel, hence entry point for target offloading.
1675   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1676 
1677   /// Cache to remember the unique kernel for a function.
1678   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1679 
1680   /// Find the unique kernel that will execute \p F, if any.
1681   Kernel getUniqueKernelFor(Function &F);
1682 
1683   /// Find the unique kernel that will execute \p I, if any.
1684   Kernel getUniqueKernelFor(Instruction &I) {
1685     return getUniqueKernelFor(*I.getFunction());
1686   }
1687 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
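  ///
  /// Roughly (a sketch; names illustrative): a parallel region wrapper that is
  /// passed to __kmpc_parallel_51 and otherwise only called directly, e.g.,
  ///
  ///   call void @__kmpc_parallel_51(..., @parallel_wrapper, ...)
  ///
  /// gets its address-taken uses replaced by a new private global,
  ///
  ///   @parallel_wrapper.ID = private constant i8 undef
  ///   call void @__kmpc_parallel_51(..., @parallel_wrapper.ID, ...)
  ///
  /// so only direct calls to the wrapper remain.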
1690   bool rewriteDeviceCodeStateMachine();
1691 
1692   ///
1693   ///}}
1694 
1695   /// Emit a remark generically
1696   ///
1697   /// This template function can be used to generically emit a remark. The
1698   /// RemarkKind should be one of the following:
1699   ///   - OptimizationRemark to indicate a successful optimization attempt
1700   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1701   ///   - OptimizationRemarkAnalysis to provide additional information about an
1702   ///     optimization attempt
1703   ///
1704   /// The remark is built using a callback function provided by the caller that
1705   /// takes a RemarkKind as input and returns a RemarkKind.
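  ///
  /// A typical use, mirroring the call sites elsewhere in this pass, looks
  /// like:
  ///
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "OpenMP runtime call "
  ///               << ore::NV("OpenMPOptRuntime", RFI.Name)
  ///               << " moved to beginning of OpenMP region";
  ///   };
  ///   emitRemark<OptimizationRemark>(&F, "OpenMPRuntimeCodeMotion", Remark);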
1706   template <typename RemarkKind, typename RemarkCallBack>
1707   void emitRemark(Instruction *I, StringRef RemarkName,
1708                   RemarkCallBack &&RemarkCB) const {
1709     Function *F = I->getParent()->getParent();
1710     auto &ORE = OREGetter(F);
1711 
1712     ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1713   }
1714 
1715   /// Emit a remark on a function.
1716   template <typename RemarkKind, typename RemarkCallBack>
1717   void emitRemark(Function *F, StringRef RemarkName,
1718                   RemarkCallBack &&RemarkCB) const {
1719     auto &ORE = OREGetter(F);
1720 
1721     ORE.emit([&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1722   }
1723 
1724   /// The underlying module.
1725   Module &M;
1726 
1727   /// The SCC we are operating on.
1728   SmallVectorImpl<Function *> &SCC;
1729 
1730   /// Callback to update the call graph, the first argument is a removed call,
1731   /// the second an optional replacement call.
1732   CallGraphUpdater &CGUpdater;
1733 
1734   /// Callback to get an OptimizationRemarkEmitter from a Function *
1735   OptimizationRemarkGetter OREGetter;
1736 
1737   /// OpenMP-specific information cache. Also Used for Attributor runs.
1738   OMPInformationCache &OMPInfoCache;
1739 
1740   /// Attributor instance.
1741   Attributor &A;
1742 
1743   /// Helper function to run Attributor on SCC.
1744   bool runAttributor(bool IsModulePass) {
1745     if (SCC.empty())
1746       return false;
1747 
1748     registerAAs(IsModulePass);
1749 
1750     ChangeStatus Changed = A.run();
1751 
1752     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1753                       << " functions, result: " << Changed << ".\n");
1754 
1755     return Changed == ChangeStatus::CHANGED;
1756   }
1757 
1758   /// Populate the Attributor with abstract attribute opportunities in the
1759   /// function.
1760   void registerAAs(bool IsModulePass);
1761 };
1762 
1763 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1764   if (!OMPInfoCache.ModuleSlice.count(&F))
1765     return nullptr;
1766 
1767   // Use a scope to keep the lifetime of the CachedKernel short.
1768   {
1769     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1770     if (CachedKernel)
1771       return *CachedKernel;
1772 
1773     // TODO: We should use an AA to create an (optimistic and callback
1774     //       call-aware) call graph. For now we stick to simple patterns that
1775     //       are less powerful, basically the worst fixpoint.
1776     if (isKernel(F)) {
1777       CachedKernel = Kernel(&F);
1778       return *CachedKernel;
1779     }
1780 
1781     CachedKernel = nullptr;
1782     if (!F.hasLocalLinkage()) {
1783 
1784       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1785       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1786         return ORA
1787                << "[OMP100] Potentially unknown OpenMP target region caller";
1788       };
1789       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1790 
1791       return nullptr;
1792     }
1793   }
1794 
1795   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1796     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1797       // Allow use in equality comparisons.
1798       if (Cmp->isEquality())
1799         return getUniqueKernelFor(*Cmp);
1800       return nullptr;
1801     }
1802     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1803       // Allow direct calls.
1804       if (CB->isCallee(&U))
1805         return getUniqueKernelFor(*CB);
1806 
1807       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1808           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1809       // Allow the use in __kmpc_parallel_51 calls.
1810       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1811         return getUniqueKernelFor(*CB);
1812       return nullptr;
1813     }
1814     // Disallow every other use.
1815     return nullptr;
1816   };
1817 
1818   // TODO: In the future we want to track more than just a unique kernel.
1819   SmallPtrSet<Kernel, 2> PotentialKernels;
1820   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1821     PotentialKernels.insert(GetUniqueKernelForUse(U));
1822   });
1823 
1824   Kernel K = nullptr;
1825   if (PotentialKernels.size() == 1)
1826     K = *PotentialKernels.begin();
1827 
1828   // Cache the result.
1829   UniqueKernelMap[&F] = K;
1830 
1831   return K;
1832 }
1833 
1834 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1835   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1836       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1837 
1838   bool Changed = false;
1839   if (!KernelParallelRFI)
1840     return Changed;
1841 
1842   for (Function *F : SCC) {
1843 
    // Check if the function is used in a __kmpc_parallel_51 call at all.
1846     bool UnknownUse = false;
1847     bool KernelParallelUse = false;
1848     unsigned NumDirectCalls = 0;
1849 
1850     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1851     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1852       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1853         if (CB->isCallee(&U)) {
1854           ++NumDirectCalls;
1855           return;
1856         }
1857 
1858       if (isa<ICmpInst>(U.getUser())) {
1859         ToBeReplacedStateMachineUses.push_back(&U);
1860         return;
1861       }
1862 
1863       // Find wrapper functions that represent parallel kernels.
1864       CallInst *CI =
1865           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1866       const unsigned int WrapperFunctionArgNo = 6;
1867       if (!KernelParallelUse && CI &&
1868           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1869         KernelParallelUse = true;
1870         ToBeReplacedStateMachineUses.push_back(&U);
1871         return;
1872       }
1873       UnknownUse = true;
1874     });
1875 
1876     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
1877     // use.
1878     if (!KernelParallelUse)
1879       continue;
1880 
1881     {
1882       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1883         return ORA << "Found a parallel region that is called in a target "
1884                       "region but not part of a combined target construct nor "
1885                       "nested inside a target construct without intermediate "
1886                       "code. This can lead to excessive register usage for "
1887                       "unrelated target regions in the same translation unit "
1888                       "due to spurious call edges assumed by ptxas.";
1889       };
1890       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD",
1891                                              Remark);
1892     }
1893 
1894     // If this ever hits, we should investigate.
1895     // TODO: Checking the number of uses is not a necessary restriction and
1896     // should be lifted.
1897     if (UnknownUse || NumDirectCalls != 1 ||
1898         ToBeReplacedStateMachineUses.size() > 2) {
1899       {
1900         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1901           return ORA << "Parallel region is used in "
1902                      << (UnknownUse ? "unknown" : "unexpected")
1903                      << " ways; will not attempt to rewrite the state machine.";
1904         };
1905         emitRemark<OptimizationRemarkAnalysis>(
1906             F, "OpenMPParallelRegionInNonSPMD", Remark);
1907       }
1908       continue;
1909     }
1910 
1911     // Even if we have __kmpc_parallel_51 calls, we (for now) give
1912     // up if the function is not called from a unique kernel.
1913     Kernel K = getUniqueKernelFor(*F);
1914     if (!K) {
1915       {
1916         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1917           return ORA << "Parallel region is not known to be called from a "
1918                         "unique single target region, maybe the surrounding "
1919                         "function has external linkage?; will not attempt to "
1920                         "rewrite the state machine use.";
1921         };
1922         emitRemark<OptimizationRemarkAnalysis>(
1923             F, "OpenMPParallelRegionInMultipleKernesl", Remark);
1924       }
1925       continue;
1926     }
1927 
1928     // We now know F is a parallel body function called only from the kernel K.
1929     // We also identified the state machine uses in which we replace the
1930     // function pointer by a new global symbol for identification purposes. This
1931     // ensures only direct calls to the function are left.
1932 
1933     {
      auto RemarkParallelRegion = [&](OptimizationRemarkAnalysis ORA) {
1935         return ORA << "Specialize parallel region that is only reached from a "
1936                       "single target region to avoid spurious call edges and "
1937                       "excessive register usage in other target regions. "
1938                       "(parallel region ID: "
1939                    << ore::NV("OpenMPParallelRegion", F->getName())
1940                    << ", kernel ID: "
1941                    << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1942       };
1943       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPParallelRegionInNonSPMD",
                                             RemarkParallelRegion);
1945       auto RemarkKernel = [&](OptimizationRemarkAnalysis ORA) {
1946         return ORA << "Target region containing the parallel region that is "
1947                       "specialized. (parallel region ID: "
1948                    << ore::NV("OpenMPParallelRegion", F->getName())
1949                    << ", kernel ID: "
1950                    << ore::NV("OpenMPTargetRegion", K->getName()) << ")";
1951       };
1952       emitRemark<OptimizationRemarkAnalysis>(K, "OpenMPParallelRegionInNonSPMD",
1953                                              RemarkKernel);
1954     }
1955 
1956     Module &M = *F->getParent();
1957     Type *Int8Ty = Type::getInt8Ty(M.getContext());
1958 
1959     auto *ID = new GlobalVariable(
1960         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
1961         UndefValue::get(Int8Ty), F->getName() + ".ID");
1962 
1963     for (Use *U : ToBeReplacedStateMachineUses)
1964       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
1965 
1966     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
1967 
1968     Changed = true;
1969   }
1970 
1971   return Changed;
1972 }
1973 
1974 /// Abstract Attribute for tracking ICV values.
1975 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
1976   using Base = StateWrapper<BooleanState, AbstractAttribute>;
1977   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
1978 
1979   void initialize(Attributor &A) override {
1980     Function *F = getAnchorScope();
1981     if (!F || !A.isFunctionIPOAmendable(*F))
1982       indicatePessimisticFixpoint();
1983   }
1984 
1985   /// Returns true if value is assumed to be tracked.
1986   bool isAssumedTracked() const { return getAssumed(); }
1987 
1988   /// Returns true if value is known to be tracked.
1989   bool isKnownTracked() const { return getAssumed(); }
1990 
  /// Create an abstract attribute view for the position \p IRP.
1992   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
1993 
1994   /// Return the value with which \p I can be replaced for specific \p ICV.
1995   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
1996                                                 const Instruction *I,
1997                                                 Attributor &A) const {
1998     return None;
1999   }
2000 
2001   /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return nullptr. If it is not clear yet, return
  /// None.
2004   virtual Optional<Value *>
2005   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2006 
  // Currently only nthreads is being tracked; this array will only grow over
  // time.
2009   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2010 
2011   /// See AbstractAttribute::getName()
2012   const std::string getName() const override { return "AAICVTracker"; }
2013 
2014   /// See AbstractAttribute::getIdAddr()
2015   const char *getIdAddr() const override { return &ID; }
2016 
2017   /// This function should return true if the type of the \p AA is AAICVTracker
2018   static bool classof(const AbstractAttribute *AA) {
2019     return (AA->getIdAddr() == &ID);
2020   }
2021 
2022   static const char ID;
2023 };
2024 
2025 struct AAICVTrackerFunction : public AAICVTracker {
2026   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2027       : AAICVTracker(IRP, A) {}
2028 
2029   // FIXME: come up with better string.
2030   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2031 
2032   // FIXME: come up with some stats.
2033   void trackStatistics() const override {}
2034 
2035   /// We don't manifest anything for this AA.
2036   ChangeStatus manifest(Attributor &A) override {
2037     return ChangeStatus::UNCHANGED;
2038   }
2039 
2040   // Map of ICV to their values at specific program point.
2041   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2042                   InternalControlVar::ICV___last>
2043       ICVReplacementValuesMap;
2044 
2045   ChangeStatus updateImpl(Attributor &A) override {
2046     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2047 
2048     Function *F = getAnchorScope();
2049 
2050     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2051 
2052     for (InternalControlVar ICV : TrackableICVs) {
2053       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2054 
2055       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2056       auto TrackValues = [&](Use &U, Function &) {
2057         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2058         if (!CI)
2059           return false;
2060 
        // FIXME: Handle setters with more than one argument.
2062         /// Track new value.
2063         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2064           HasChanged = ChangeStatus::CHANGED;
2065 
2066         return false;
2067       };
2068 
2069       auto CallCheck = [&](Instruction &I) {
2070         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2071         if (ReplVal.hasValue() &&
2072             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2073           HasChanged = ChangeStatus::CHANGED;
2074 
2075         return true;
2076       };
2077 
2078       // Track all changes of an ICV.
2079       SetterRFI.foreachUse(TrackValues, F);
2080 
2081       bool UsedAssumedInformation = false;
2082       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2083                                 UsedAssumedInformation,
2084                                 /* CheckBBLivenessOnly */ true);
2085 
      /// TODO: Figure out a way to avoid adding an entry to
      /// ICVReplacementValuesMap.
2088       Instruction *Entry = &F->getEntryBlock().front();
2089       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2090         ValuesMap.insert(std::make_pair(Entry, nullptr));
2091     }
2092 
2093     return HasChanged;
2094   }
2095 
  /// Helper to check if \p I is a call and get the value for it if it is
  /// unique.
2098   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2099                                     InternalControlVar &ICV) const {
2100 
2101     const auto *CB = dyn_cast<CallBase>(I);
2102     if (!CB || CB->hasFnAttr("no_openmp") ||
2103         CB->hasFnAttr("no_openmp_routines"))
2104       return None;
2105 
2106     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2107     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2108     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2109     Function *CalledFunction = CB->getCalledFunction();
2110 
2111     // Indirect call, assume ICV changes.
2112     if (CalledFunction == nullptr)
2113       return nullptr;
2114     if (CalledFunction == GetterRFI.Declaration)
2115       return None;
2116     if (CalledFunction == SetterRFI.Declaration) {
2117       if (ICVReplacementValuesMap[ICV].count(I))
2118         return ICVReplacementValuesMap[ICV].lookup(I);
2119 
2120       return nullptr;
2121     }
2122 
2123     // Since we don't know, assume it changes the ICV.
2124     if (CalledFunction->isDeclaration())
2125       return nullptr;
2126 
2127     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2128         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2129 
2130     if (ICVTrackingAA.isAssumedTracked())
2131       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2132 
2133     // If we don't know, assume it changes.
2134     return nullptr;
2135   }
2136 
2137   // We don't check unique value for a function, so return None.
2138   Optional<Value *>
2139   getUniqueReplacementValue(InternalControlVar ICV) const override {
2140     return None;
2141   }
2142 
2143   /// Return the value with which \p I can be replaced for specific \p ICV.
2144   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2145                                         const Instruction *I,
2146                                         Attributor &A) const override {
2147     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2148     if (ValuesMap.count(I))
2149       return ValuesMap.lookup(I);
2150 
2151     SmallVector<const Instruction *, 16> Worklist;
2152     SmallPtrSet<const Instruction *, 16> Visited;
2153     Worklist.push_back(I);
2154 
2155     Optional<Value *> ReplVal;
2156 
2157     while (!Worklist.empty()) {
2158       const Instruction *CurrInst = Worklist.pop_back_val();
2159       if (!Visited.insert(CurrInst).second)
2160         continue;
2161 
2162       const BasicBlock *CurrBB = CurrInst->getParent();
2163 
2164       // Go up and look for all potential setters/calls that might change the
2165       // ICV.
2166       while ((CurrInst = CurrInst->getPrevNode())) {
2167         if (ValuesMap.count(CurrInst)) {
2168           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2169           // Unknown value, track new.
2170           if (!ReplVal.hasValue()) {
2171             ReplVal = NewReplVal;
2172             break;
2173           }
2174 
          // If we found a new value, we can't know the ICV value anymore.
2176           if (NewReplVal.hasValue())
2177             if (ReplVal != NewReplVal)
2178               return nullptr;
2179 
2180           break;
2181         }
2182 
2183         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2184         if (!NewReplVal.hasValue())
2185           continue;
2186 
2187         // Unknown value, track new.
2188         if (!ReplVal.hasValue()) {
2189           ReplVal = NewReplVal;
2190           break;
2191         }
2192 
        // We found a new value, so we cannot know the ICV value anymore.
2195         if (ReplVal != NewReplVal)
2196           return nullptr;
2197       }
2198 
2199       // If we are in the same BB and we have a value, we are done.
2200       if (CurrBB == I->getParent() && ReplVal.hasValue())
2201         return ReplVal;
2202 
2203       // Go through all predecessors and add terminators for analysis.
2204       for (const BasicBlock *Pred : predecessors(CurrBB))
2205         if (const Instruction *Terminator = Pred->getTerminator())
2206           Worklist.push_back(Terminator);
2207     }
2208 
2209     return ReplVal;
2210   }
2211 };
2212 
2213 struct AAICVTrackerFunctionReturned : AAICVTracker {
2214   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2215       : AAICVTracker(IRP, A) {}
2216 
2217   // FIXME: come up with better string.
2218   const std::string getAsStr() const override {
2219     return "ICVTrackerFunctionReturned";
2220   }
2221 
2222   // FIXME: come up with some stats.
2223   void trackStatistics() const override {}
2224 
2225   /// We don't manifest anything for this AA.
2226   ChangeStatus manifest(Attributor &A) override {
2227     return ChangeStatus::UNCHANGED;
2228   }
2229 
2230   // Map of ICV to their values at specific program point.
2231   EnumeratedArray<Optional<Value *>, InternalControlVar,
2232                   InternalControlVar::ICV___last>
2233       ICVReplacementValuesMap;
2234 
  /// Return the unique ICV value assumed at the returns of this function for
  /// the specific \p ICV, if any.
2236   Optional<Value *>
2237   getUniqueReplacementValue(InternalControlVar ICV) const override {
2238     return ICVReplacementValuesMap[ICV];
2239   }
2240 
2241   ChangeStatus updateImpl(Attributor &A) override {
2242     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2243     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2244         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2245 
2246     if (!ICVTrackingAA.isAssumedTracked())
2247       return indicatePessimisticFixpoint();
2248 
2249     for (InternalControlVar ICV : TrackableICVs) {
2250       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2251       Optional<Value *> UniqueICVValue;
2252 
2253       auto CheckReturnInst = [&](Instruction &I) {
2254         Optional<Value *> NewReplVal =
2255             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2256 
2257         // If we found a second ICV value there is no unique returned value.
2258         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2259           return false;
2260 
2261         UniqueICVValue = NewReplVal;
2262 
2263         return true;
2264       };
2265 
2266       bool UsedAssumedInformation = false;
2267       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2268                                      UsedAssumedInformation,
2269                                      /* CheckBBLivenessOnly */ true))
2270         UniqueICVValue = nullptr;
2271 
2272       if (UniqueICVValue == ReplVal)
2273         continue;
2274 
2275       ReplVal = UniqueICVValue;
2276       Changed = ChangeStatus::CHANGED;
2277     }
2278 
2279     return Changed;
2280   }
2281 };
2282 
2283 struct AAICVTrackerCallSite : AAICVTracker {
2284   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2285       : AAICVTracker(IRP, A) {}
2286 
2287   void initialize(Attributor &A) override {
2288     Function *F = getAnchorScope();
2289     if (!F || !A.isFunctionIPOAmendable(*F))
2290       indicatePessimisticFixpoint();
2291 
2292     // We only initialize this AA for getters, so we need to know which ICV it
2293     // gets.
2294     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2295     for (InternalControlVar ICV : TrackableICVs) {
2296       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2297       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2298       if (Getter.Declaration == getAssociatedFunction()) {
2299         AssociatedICV = ICVInfo.Kind;
2300         return;
2301       }
2302     }
2303 
2304     /// Unknown ICV.
2305     indicatePessimisticFixpoint();
2306   }
2307 
2308   ChangeStatus manifest(Attributor &A) override {
2309     if (!ReplVal.hasValue() || !ReplVal.getValue())
2310       return ChangeStatus::UNCHANGED;
2311 
2312     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2313     A.deleteAfterManifest(*getCtxI());
2314 
2315     return ChangeStatus::CHANGED;
2316   }
2317 
2318   // FIXME: come up with better string.
2319   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2320 
2321   // FIXME: come up with some stats.
2322   void trackStatistics() const override {}
2323 
2324   InternalControlVar AssociatedICV;
2325   Optional<Value *> ReplVal;
2326 
2327   ChangeStatus updateImpl(Attributor &A) override {
2328     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2329         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2330 
2331     // We don't have any information, so we assume it changes the ICV.
2332     if (!ICVTrackingAA.isAssumedTracked())
2333       return indicatePessimisticFixpoint();
2334 
2335     Optional<Value *> NewReplVal =
2336         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2337 
2338     if (ReplVal == NewReplVal)
2339       return ChangeStatus::UNCHANGED;
2340 
2341     ReplVal = NewReplVal;
2342     return ChangeStatus::CHANGED;
2343   }
2344 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2347   Optional<Value *>
2348   getUniqueReplacementValue(InternalControlVar ICV) const override {
2349     return ReplVal;
2350   }
2351 };
2352 
2353 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2354   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2355       : AAICVTracker(IRP, A) {}
2356 
2357   // FIXME: come up with better string.
2358   const std::string getAsStr() const override {
2359     return "ICVTrackerCallSiteReturned";
2360   }
2361 
2362   // FIXME: come up with some stats.
2363   void trackStatistics() const override {}
2364 
2365   /// We don't manifest anything for this AA.
2366   ChangeStatus manifest(Attributor &A) override {
2367     return ChangeStatus::UNCHANGED;
2368   }
2369 
2370   // Map of ICV to their values at specific program point.
2371   EnumeratedArray<Optional<Value *>, InternalControlVar,
2372                   InternalControlVar::ICV___last>
2373       ICVReplacementValuesMap;
2374 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2377   Optional<Value *>
2378   getUniqueReplacementValue(InternalControlVar ICV) const override {
2379     return ICVReplacementValuesMap[ICV];
2380   }
2381 
2382   ChangeStatus updateImpl(Attributor &A) override {
2383     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2384     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2385         *this, IRPosition::returned(*getAssociatedFunction()),
2386         DepClassTy::REQUIRED);
2387 
2388     // We don't have any information, so we assume it changes the ICV.
2389     if (!ICVTrackingAA.isAssumedTracked())
2390       return indicatePessimisticFixpoint();
2391 
2392     for (InternalControlVar ICV : TrackableICVs) {
2393       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2394       Optional<Value *> NewReplVal =
2395           ICVTrackingAA.getUniqueReplacementValue(ICV);
2396 
2397       if (ReplVal == NewReplVal)
2398         continue;
2399 
2400       ReplVal = NewReplVal;
2401       Changed = ChangeStatus::CHANGED;
2402     }
2403     return Changed;
2404   }
2405 };
2406 
2407 struct AAExecutionDomainFunction : public AAExecutionDomain {
2408   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2409       : AAExecutionDomain(IRP, A) {}
2410 
2411   const std::string getAsStr() const override {
2412     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2413            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2414   }
2415 
2416   /// See AbstractAttribute::trackStatistics().
2417   void trackStatistics() const override {}
2418 
2419   void initialize(Attributor &A) override {
2420     Function *F = getAnchorScope();
2421     for (const auto &BB : *F)
2422       SingleThreadedBBs.insert(&BB);
2423     NumBBs = SingleThreadedBBs.size();
2424   }
2425 
2426   ChangeStatus manifest(Attributor &A) override {
2427     LLVM_DEBUG({
2428       for (const BasicBlock *BB : SingleThreadedBBs)
2429         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2430                << BB->getName() << " is executed by a single thread.\n";
2431     });
2432     return ChangeStatus::UNCHANGED;
2433   }
2434 
2435   ChangeStatus updateImpl(Attributor &A) override;
2436 
2437   /// Check if an instruction is executed by a single thread.
2438   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2439     return isExecutedByInitialThreadOnly(*I.getParent());
2440   }
2441 
2442   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2443     return isValidState() && SingleThreadedBBs.contains(&BB);
2444   }
2445 
2446   /// Set of basic blocks that are executed by a single thread.
2447   DenseSet<const BasicBlock *> SingleThreadedBBs;
2448 
2449   /// Total number of basic blocks in this function.
2450   long unsigned NumBBs;
2451 };
2452 
2453 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2454   Function *F = getAnchorScope();
2455   ReversePostOrderTraversal<Function *> RPOT(F);
2456   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2457 
2458   bool AllCallSitesKnown;
2459   auto PredForCallSite = [&](AbstractCallSite ACS) {
2460     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2461         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2462         DepClassTy::REQUIRED);
2463     return ACS.isDirectCall() &&
2464            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2465                *ACS.getInstruction());
2466   };
2467 
2468   if (!A.checkForAllCallSites(PredForCallSite, *this,
2469                               /* RequiresAllCallSites */ true,
2470                               AllCallSitesKnown))
2471     SingleThreadedBBs.erase(&F->getEntryBlock());
2472 
2473   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2474   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2475 
2476   // Check if the edge into the successor block compares the __kmpc_target_init
2477   // result with -1. If we are in non-SPMD-mode that signals only the main
2478   // thread will execute the edge.
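  //
  // A sketch of the pattern we look for (argument lists abbreviated, names
  // illustrative):
  //
  //   %tid = call i32 @__kmpc_target_init(..., i1 false /* IsSPMD */, ...)
  //   %cmp = icmp eq i32 %tid, -1
  //   br i1 %cmp, label %initial.thread.only, label %others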
2479   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2480     if (!Edge || !Edge->isConditional())
2481       return false;
2482     if (Edge->getSuccessor(0) != SuccessorBB)
2483       return false;
2484 
2485     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2486     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2487       return false;
2488 
2489     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2490     if (!C)
2491       return false;
2492 
2493     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2494     if (C->isAllOnesValue()) {
2495       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2496       if (!CB || CB->getCalledFunction() != RFI.Declaration)
2497         return false;
2498       const int InitIsSPMDArgNo = 1;
2499       auto *IsSPMDModeCI =
2500           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2501       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2502     }
2503 
2504     return false;
2505   };
2506 
2507   // Merge all the predecessor states into the current basic block. A basic
2508   // block is executed by a single thread if all of its predecessors are.
2509   auto MergePredecessorStates = [&](BasicBlock *BB) {
2510     if (pred_begin(BB) == pred_end(BB))
2511       return SingleThreadedBBs.contains(BB);
2512 
2513     bool IsInitialThread = true;
2514     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2515          PredBB != PredEndBB; ++PredBB) {
2516       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2517                                BB))
2518         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2519     }
2520 
2521     return IsInitialThread;
2522   };
2523 
2524   for (auto *BB : RPOT) {
2525     if (!MergePredecessorStates(BB))
2526       SingleThreadedBBs.erase(BB);
2527   }
2528 
2529   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2530              ? ChangeStatus::UNCHANGED
2531              : ChangeStatus::CHANGED;
2532 }
2533 
2534 /// Try to replace memory allocation calls called by a single thread with a
2535 /// static buffer of shared memory.
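///
/// A minimal sketch of the rewrite (argument lists abbreviated, names
/// illustrative; the shared address space is target dependent, e.g.,
/// addrspace(3) on NVPTX):
///
///   %p = call i8* @__kmpc_alloc_shared(i64 24)
///   ...
///   call void @__kmpc_free_shared(i8* %p)
///
/// becomes a module-level buffer in the shared address space whose address
/// replaces all uses of %p, while the allocation and free calls are removed:
///
///   @p_shared = internal addrspace(3) global [24 x i8] undef, align 32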
2536 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2537   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2538   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2539 
2540   /// Create an abstract attribute view for the position \p IRP.
2541   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2542                                            Attributor &A);
2543 
2544   /// See AbstractAttribute::getName().
2545   const std::string getName() const override { return "AAHeapToShared"; }
2546 
2547   /// See AbstractAttribute::getIdAddr().
2548   const char *getIdAddr() const override { return &ID; }
2549 
2550   /// This function should return true if the type of the \p AA is
2551   /// AAHeapToShared.
2552   static bool classof(const AbstractAttribute *AA) {
2553     return (AA->getIdAddr() == &ID);
2554   }
2555 
2556   /// Unique ID (due to the unique address)
2557   static const char ID;
2558 };
2559 
2560 struct AAHeapToSharedFunction : public AAHeapToShared {
2561   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2562       : AAHeapToShared(IRP, A) {}
2563 
2564   const std::string getAsStr() const override {
2565     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2566            " malloc calls eligible.";
2567   }
2568 
2569   /// See AbstractAttribute::trackStatistics().
2570   void trackStatistics() const override {}
2571 
2572   void initialize(Attributor &A) override {
2573     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2574     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2575 
2576     for (User *U : RFI.Declaration->users())
2577       if (CallBase *CB = dyn_cast<CallBase>(U))
2578         MallocCalls.insert(CB);
2579   }
2580 
2581   ChangeStatus manifest(Attributor &A) override {
2582     if (MallocCalls.empty())
2583       return ChangeStatus::UNCHANGED;
2584 
2585     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2586     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2587 
2588     Function *F = getAnchorScope();
2589     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2590                                             DepClassTy::OPTIONAL);
2591 
2592     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2593     for (CallBase *CB : MallocCalls) {
2594       // Skip replacing this if HeapToStack has already claimed it.
2595       if (HS && HS->isAssumedHeapToStack(*CB))
2596         continue;
2597 
2598       // Find the unique free call to remove it.
2599       SmallVector<CallBase *, 4> FreeCalls;
2600       for (auto *U : CB->users()) {
2601         CallBase *C = dyn_cast<CallBase>(U);
2602         if (C && C->getCalledFunction() == FreeCall.Declaration)
2603           FreeCalls.push_back(C);
2604       }
2605       if (FreeCalls.size() != 1)
2606         continue;
2607 
2608       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2609 
2610       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call in "
2611                         << CB->getCaller()->getName() << " with "
2612                         << AllocSize->getZExtValue()
2613                         << " bytes of shared memory\n");
2614 
2615       // Create a new shared memory buffer of the same size as the allocation
2616       // and replace all the uses of the original allocation with it.
2617       Module *M = CB->getModule();
2618       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2619       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2620       auto *SharedMem = new GlobalVariable(
2621           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2622           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2623           GlobalValue::NotThreadLocal,
2624           static_cast<unsigned>(AddressSpace::Shared));
2625       auto *NewBuffer =
2626           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2627 
2628       auto Remark = [&](OptimizationRemark OR) {
2629         return OR << "Replaced globalized variable with "
2630                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2631                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2632                   << "of shared memory";
2633       };
2634       A.emitRemark<OptimizationRemark>(CB, "OpenMPReplaceGlobalization",
2635                                        Remark);
2636 
2637       SharedMem->setAlignment(MaybeAlign(32));
2638 
2639       A.changeValueAfterManifest(*CB, *NewBuffer);
2640       A.deleteAfterManifest(*CB);
2641       A.deleteAfterManifest(*FreeCalls.front());
2642 
2643       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2644       Changed = ChangeStatus::CHANGED;
2645     }
2646 
2647     return Changed;
2648   }
2649 
2650   ChangeStatus updateImpl(Attributor &A) override {
2651     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2652     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2653     Function *F = getAnchorScope();
2654 
2655     auto NumMallocCalls = MallocCalls.size();
2656 
    // Only consider malloc calls executed by a single thread with a constant
    // allocation size.
2658     for (User *U : RFI.Declaration->users()) {
2659       const auto &ED = A.getAAFor<AAExecutionDomain>(
2660           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2661       if (CallBase *CB = dyn_cast<CallBase>(U))
2662         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2663             !ED.isExecutedByInitialThreadOnly(*CB))
2664           MallocCalls.erase(CB);
2665     }
2666 
2667     if (NumMallocCalls != MallocCalls.size())
2668       return ChangeStatus::CHANGED;
2669 
2670     return ChangeStatus::UNCHANGED;
2671   }
2672 
2673   /// Collection of all malloc calls in a function.
2674   SmallPtrSet<CallBase *, 4> MallocCalls;
2675 };
2676 
2677 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2678   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2679   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2680 
2681   /// Statistics are tracked as part of manifest for now.
2682   void trackStatistics() const override {}
2683 
2684   /// See AbstractAttribute::getAsStr()
2685   const std::string getAsStr() const override {
2686     if (!isValidState())
2687       return "<invalid>";
2688     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2689                                                             : "generic") +
2690            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2691                                                                : "") +
2692            std::string(" #PRs: ") +
2693            std::to_string(ReachedKnownParallelRegions.size()) +
2694            ", #Unknown PRs: " +
2695            std::to_string(ReachedUnknownParallelRegions.size());
2696   }
2697 
  /// Create an abstract attribute view for the position \p IRP.
2699   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2700 
2701   /// See AbstractAttribute::getName()
2702   const std::string getName() const override { return "AAKernelInfo"; }
2703 
2704   /// See AbstractAttribute::getIdAddr()
2705   const char *getIdAddr() const override { return &ID; }
2706 
2707   /// This function should return true if the type of the \p AA is AAKernelInfo
2708   static bool classof(const AbstractAttribute *AA) {
2709     return (AA->getIdAddr() == &ID);
2710   }
2711 
2712   static const char ID;
2713 };
2714 
2715 /// The function kernel info abstract attribute, basically, what can we say
2716 /// about a function with regards to the KernelInfoState.
2717 struct AAKernelInfoFunction : AAKernelInfo {
2718   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2719       : AAKernelInfo(IRP, A) {}
2720 
2721   /// See AbstractAttribute::initialize(...).
2722   void initialize(Attributor &A) override {
    // This is a high-level transform that might change the constant arguments
    // of the init and deinit calls. We need to tell the Attributor about this
    // to avoid other parts using the current constant value for
    // simplification.
2726     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2727 
2728     Function *Fn = getAnchorScope();
2729     if (!OMPInfoCache.Kernels.count(Fn))
2730       return;
2731 
2732     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2733         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2734     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2735         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2736 
2737     // For kernels we perform more initialization work: first we find the init
2738     // and deinit calls.
2739     auto StoreCallBase = [](Use &U,
2740                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2741                             CallBase *&Storage) {
2742       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2743       assert(CB &&
2744              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2745       assert(!Storage &&
2746              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2747       Storage = CB;
2748       return false;
2749     };
2750     InitRFI.foreachUse(
2751         [&](Use &U, Function &) {
2752           StoreCallBase(U, InitRFI, KernelInitCB);
2753           return false;
2754         },
2755         Fn);
2756     DeinitRFI.foreachUse(
2757         [&](Use &U, Function &) {
2758           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2759           return false;
2760         },
2761         Fn);
2762 
2763     assert((KernelInitCB && KernelDeinitCB) &&
2764            "Kernel without __kmpc_target_init or __kmpc_target_deinit!");
2765 
2766     // For kernels we might need to initialize/finalize the IsSPMD state and
2767     // we need to register a simplification callback so that the Attributor
2768     // knows the constant arguments to __kmpc_target_init and
2769     // __kmpc_target_deinit might actually change.
2770 
2771     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2772         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2773             bool &UsedAssumedInformation) -> Optional<Value *> {
2774       // IRP represents the "use generic state machine" argument of an
2775       // __kmpc_target_init call. We will answer this one with the internal
2776       // state. As long as we are not in an invalid state, we will create a
2777       // custom state machine so the value should be a `i1 false`. If we are
2778       // in an invalid state, we won't change the value that is in the IR.
2779       if (!isValidState())
2780         return nullptr;
2781       if (AA)
2782         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2783       UsedAssumedInformation = !isAtFixpoint();
2784       auto *FalseVal =
2785           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2786       return FalseVal;
2787     };
2788 
2789     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2790         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2791             bool &UsedAssumedInformation) -> Optional<Value *> {
2792       // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
2793       // __kmpc_target_deinit call. We will answer this one with the internal
2794       // state, i.e., what the SPMDCompatibilityTracker currently assumes.
2796       if (!isValidState())
2797         return nullptr;
2798       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2799         if (AA)
2800           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2801         UsedAssumedInformation = true;
2802       } else {
2803         UsedAssumedInformation = false;
2804       }
2805       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2806                                        SPMDCompatibilityTracker.isAssumed());
2807       return Val;
2808     };
2809 
2810     constexpr const int InitIsSPMDArgNo = 1;
2811     constexpr const int DeinitIsSPMDArgNo = 1;
2812     constexpr const int InitUseStateMachineArgNo = 2;
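    // NOTE (assumption): these indices mirror the device runtime entry points as
    // they are used throughout this file, roughly
    //   __kmpc_target_init(%ident, i1 %IsSPMD, i1 %UseGenericStateMachine, ...)
    //   __kmpc_target_deinit(%ident, i1 %IsSPMD, ...)
    // and must be kept in sync with the runtime if its signatures change.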
2813     A.registerSimplificationCallback(
2814         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2815         StateMachineSimplifyCB);
2816     A.registerSimplificationCallback(
2817         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2818         IsSPMDModeSimplifyCB);
2819     A.registerSimplificationCallback(
2820         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2821         IsSPMDModeSimplifyCB);
2822 
2823     // Check if we know we are in SPMD-mode already.
2824     ConstantInt *IsSPMDArg =
2825         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2826     if (IsSPMDArg && !IsSPMDArg->isZero())
2827       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
2828   }
2829 
2830   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
2831   /// finished now.
2832   ChangeStatus manifest(Attributor &A) override {
2833     // If we are not looking at a kernel with __kmpc_target_init and
2834     // __kmpc_target_deinit calls, we cannot actually manifest the information.
2835     if (!KernelInitCB || !KernelDeinitCB)
2836       return ChangeStatus::UNCHANGED;
2837 
2838     // Known SPMD-mode kernels need no manifest changes.
2839     if (SPMDCompatibilityTracker.isKnown())
2840       return ChangeStatus::UNCHANGED;
2841 
2842     // If we can, we change the execution mode to SPMD-mode; otherwise we build
2843     // a custom state machine.
2844     if (!changeToSPMDMode(A))
2845       buildCustomStateMachine(A);
2846 
2847     return ChangeStatus::CHANGED;
2848   }
2849 
2850   bool changeToSPMDMode(Attributor &A) {
2851     if (!SPMDCompatibilityTracker.isAssumed()) {
2852       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
2853         if (!NonCompatibleI)
2854           continue;
2855         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2856           ORA << "Kernel will be executed in generic-mode due to this "
2857                  "potential side-effect";
2858           if (auto *CI = dyn_cast<CallBase>(NonCompatibleI)) {
2859             if (Function *F = CI->getCalledFunction())
2860             ORA << ", consider adding "
2861                      "`__attribute__((assume(\"ompx_spmd_amenable\")))`"
2862                      " to the called function '"
2863                   << F->getName() << "'";
2864           }
2865           return ORA << ".";
2866         };
2867         A.emitRemark<OptimizationRemarkAnalysis>(
2868             NonCompatibleI, "OpenMPKernelNonSPMDMode", Remark);
2869 
2870         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
2871                           << *NonCompatibleI << "\n");
2872       }
2873 
2874       return false;
2875     }
2876 
2877     // Adjust the global exec mode flag that tells the runtime what mode this
2878     // kernel is executed in.
2879     Function *Kernel = getAnchorScope();
2880     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
2881         (Kernel->getName() + "_exec_mode").str());
2882     assert(ExecMode && "Kernel without exec mode?");
2883     assert(ExecMode->getInitializer() &&
2884            ExecMode->getInitializer()->isOneValue() &&
2885            "Initially non-SPMD kernel has SPMD exec mode!");
2886     ExecMode->setInitializer(
2887         ConstantInt::get(ExecMode->getInitializer()->getType(), 0));
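    // The runtime consults this "<kernel>_exec_mode" global when launching the
    // kernel; judging from the assertion and the update above, a value of 1
    // denotes generic-mode and 0 denotes SPMD-mode.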
2888 
2889     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
2890     const int InitIsSPMDArgNo = 1;
2891     const int DeinitIsSPMDArgNo = 1;
2892     const int InitUseStateMachineArgNo = 2;
2893 
2894     auto &Ctx = getAnchorValue().getContext();
2895     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
2896                              *ConstantInt::getBool(Ctx, 1));
2897     A.changeUseAfterManifest(
2898         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
2899         *ConstantInt::getBool(Ctx, 0));
2900     A.changeUseAfterManifest(
2901         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
2902         *ConstantInt::getBool(Ctx, 1));
2903     ++NumOpenMPTargetRegionKernelsSPMD;
2904 
2905     auto Remark = [&](OptimizationRemark OR) {
2906       return OR << "Generic-mode kernel is changed to SPMD-mode.";
2907     };
2908     A.emitRemark<OptimizationRemark>(KernelInitCB, "OpenMPKernelSPMDMode",
2909                                      Remark);
2910     return true;
2911   }
2912 
2913   ChangeStatus buildCustomStateMachine(Attributor &A) {
2914     assert(ReachedKnownParallelRegions.isValidState() &&
2915            "Custom state machine with invalid parallel region states?");
2916 
2917     const int InitIsSPMDArgNo = 1;
2918     const int InitUseStateMachineArgNo = 2;
2919 
2920     // Check if the current configuration is non-SPMD and uses the generic state
2921     // machine. If we already have SPMD-mode or a custom state machine, we do not
2922     // need to go any further. If either argument is anything but a constant,
2923     // something is weird and we give up.
2924     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
2925         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
2926     ConstantInt *IsSPMD =
2927         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2928 
2929     // If we are stuck with generic mode, try to create a custom device (=GPU)
2930     // state machine which is specialized for the parallel regions that are
2931     // reachable by the kernel.
2932     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
2933         !IsSPMD->isZero())
2934       return ChangeStatus::UNCHANGED;
2935 
2936     // If not SPMD mode, indicate we use a custom state machine now.
2937     auto &Ctx = getAnchorValue().getContext();
2938     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
2939     A.changeUseAfterManifest(
2940         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
2941 
2942     // If we don't actually need a state machine we are done here. This can
2943     // happen if there simply are no parallel regions. In the resulting kernel
2944     // all worker threads will simply exit right away, leaving the main thread
2945     // to do the work alone.
2946     if (ReachedKnownParallelRegions.empty() &&
2947         ReachedUnknownParallelRegions.empty()) {
2948       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
2949 
2950       auto Remark = [&](OptimizationRemark OR) {
2951         return OR << "Generic-mode kernel is executed without state machine "
2952                      "(good)";
2953       };
2954       A.emitRemark<OptimizationRemark>(
2955           KernelInitCB, "OpenMPKernelWithoutStateMachine", Remark);
2956 
2957       return ChangeStatus::CHANGED;
2958     }
2959 
2960     // Keep track in the statistics of our new shiny custom state machine.
2961     if (ReachedUnknownParallelRegions.empty()) {
2962       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
2963 
2964       auto Remark = [&](OptimizationRemark OR) {
2965         return OR << "Generic-mode kernel is executed with a customized state "
2966                      "machine ["
2967                   << ore::NV("ParallelRegions",
2968                              ReachedKnownParallelRegions.size())
2969                   << " known parallel regions] (good).";
2970       };
2971       A.emitRemark<OptimizationRemark>(
2972           KernelInitCB, "OpenMPKernelWithCustomizedStateMachine", Remark);
2973     } else {
2974       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
2975 
2976       auto Remark = [&](OptimizationRemark OR) {
2977         return OR << "Generic-mode kernel is executed with a customized state "
2978                      "machine that requires a fallback ["
2979                   << ore::NV("ParallelRegions",
2980                              ReachedKnownParallelRegions.size())
2981                   << " known parallel regions, "
2982                   << ore::NV("UnknownParallelRegions",
2983                              ReachedUnknownParallelRegions.size())
2984                   << " unknown parallel regions] (bad).";
2985       };
2986       A.emitRemark<OptimizationRemark>(
2987           KernelInitCB, "OpenMPKernelWithCustomizedStateMachineAndFallback",
2988           Remark);
2989 
2990       // Tell the user why we ended up with a fallback.
2991       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
2992         if (!UnknownParallelRegionCB)
2993           continue;
2994         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2995           return ORA
2996                  << "State machine fallback caused by this call. If it is a "
2997                     "false positive, use "
2998                     "`__attribute__((assume(\"omp_no_openmp\")))` "
2999                     "(or \"omp_no_parallelism\").";
3000         };
3001         A.emitRemark<OptimizationRemarkAnalysis>(
3002             UnknownParallelRegionCB,
3003             "OpenMPKernelWithCustomizedStateMachineAndFallback", Remark);
3004       }
3005     }
3006 
3007     // Create all the blocks:
3008     //
3009     //                       InitCB = __kmpc_target_init(...)
3010     //                       bool IsWorker = InitCB >= 0;
3011     //                       if (IsWorker) {
3012     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3013     //                         void *WorkFn;
3014     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3015     //                         if (!WorkFn) return;
3016     // SMIsActiveCheckBB:       if (Active) {
3017     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3018     //                              ParFn0(...);
3019     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3020     //                              ParFn1(...);
3021     //                            ...
3022     // SMIfCascadeCurrentBB:      else
3023     //                              ((WorkFnTy*)WorkFn)(...);
3024     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3025     //                          }
3026     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3027     //                          goto SMBeginBB;
3028     //                       }
3029     // UserCodeEntryBB:      // user code
3030     //                       __kmpc_target_deinit(...)
3031     //
3032     Function *Kernel = getAssociatedFunction();
3033     assert(Kernel && "Expected an associated function!");
3034 
3035     BasicBlock *InitBB = KernelInitCB->getParent();
3036     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3037         KernelInitCB->getNextNode(), "thread.user_code.check");
3038     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3039         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3040     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3041         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3042     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3043         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3044     BasicBlock *StateMachineIfCascadeCurrentBB =
3045         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3046                            Kernel, UserCodeEntryBB);
3047     BasicBlock *StateMachineEndParallelBB =
3048         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3049                            Kernel, UserCodeEntryBB);
3050     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3051         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3052 
3053     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3054     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3055 
3056     InitBB->getTerminator()->eraseFromParent();
3057     Instruction *IsWorker =
3058         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3059                          ConstantInt::get(KernelInitCB->getType(), -1),
3060                          "thread.is_worker", InitBB);
3061     IsWorker->setDebugLoc(DLoc);
3062     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3063 
3064     // Create local storage for the work function pointer.
3065     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3066     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3067                                           &Kernel->getEntryBlock().front());
3068     WorkFnAI->setDebugLoc(DLoc);
3069 
3070     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3071     OMPInfoCache.OMPBuilder.updateToLocation(
3072         OpenMPIRBuilder::LocationDescription(
3073             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3074                                      StateMachineBeginBB->end()),
3075             DLoc));
3076 
3077     Value *Ident = KernelInitCB->getArgOperand(0);
3078     Value *GTid = KernelInitCB;
3079 
3080     Module &M = *Kernel->getParent();
3081     FunctionCallee BarrierFn =
3082         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3083             M, OMPRTL___kmpc_barrier_simple_spmd);
3084     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3085         ->setDebugLoc(DLoc);
3086 
3087     FunctionCallee KernelParallelFn =
3088         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3089             M, OMPRTL___kmpc_kernel_parallel);
3090     Instruction *IsActiveWorker = CallInst::Create(
3091         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3092     IsActiveWorker->setDebugLoc(DLoc);
3093     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3094                                        StateMachineBeginBB);
3095     WorkFn->setDebugLoc(DLoc);
3096 
3097     FunctionType *ParallelRegionFnTy = FunctionType::get(
3098         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3099         false);
3100     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3101         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3102         StateMachineBeginBB);
3103 
3104     Instruction *IsDone =
3105         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3106                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3107                          StateMachineBeginBB);
3108     IsDone->setDebugLoc(DLoc);
3109     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3110                        IsDone, StateMachineBeginBB)
3111         ->setDebugLoc(DLoc);
3112 
3113     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3114                        StateMachineDoneBarrierBB, IsActiveWorker,
3115                        StateMachineIsActiveCheckBB)
3116         ->setDebugLoc(DLoc);
3117 
3118     Value *ZeroArg =
3119         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3120 
3121     // Now that we have most of the CFG skeleton it is time for the if-cascade
3122     // that checks the function pointer we got from the runtime against the
3123     // parallel regions we expect, if there are any.
3124     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3125       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3126       BasicBlock *PRExecuteBB = BasicBlock::Create(
3127           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3128           StateMachineEndParallelBB);
3129       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3130           ->setDebugLoc(DLoc);
3131       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3132           ->setDebugLoc(DLoc);
3133 
3134       BasicBlock *PRNextBB =
3135           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3136                              Kernel, StateMachineEndParallelBB);
3137 
3138       // Check if we need to compare the pointer at all or if we can just
3139       // call the parallel region function.
3140       Value *IsPR;
3141       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3142         Instruction *CmpI = ICmpInst::Create(
3143             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3144             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3145         CmpI->setDebugLoc(DLoc);
3146         IsPR = CmpI;
3147       } else {
3148         IsPR = ConstantInt::getTrue(Ctx);
3149       }
3150 
3151       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3152                          StateMachineIfCascadeCurrentBB)
3153           ->setDebugLoc(DLoc);
3154       StateMachineIfCascadeCurrentBB = PRNextBB;
3155     }
3156 
3157     // At the end of the if-cascade we place the indirect function pointer call
3158     // in case we might need it, that is if there can be parallel regions we
3159     // have not handled in the if-cascade above.
3160     if (!ReachedUnknownParallelRegions.empty()) {
3161       StateMachineIfCascadeCurrentBB->setName(
3162           "worker_state_machine.parallel_region.fallback.execute");
3163       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3164                        StateMachineIfCascadeCurrentBB)
3165           ->setDebugLoc(DLoc);
3166     }
3167     BranchInst::Create(StateMachineEndParallelBB,
3168                        StateMachineIfCascadeCurrentBB)
3169         ->setDebugLoc(DLoc);
3170 
3171     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3172                          M, OMPRTL___kmpc_kernel_end_parallel),
3173                      {}, "", StateMachineEndParallelBB)
3174         ->setDebugLoc(DLoc);
3175     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3176         ->setDebugLoc(DLoc);
3177 
3178     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3179         ->setDebugLoc(DLoc);
3180     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3181         ->setDebugLoc(DLoc);
3182 
3183     return ChangeStatus::CHANGED;
3184   }
3185 
3186   /// Fixpoint iteration update function. Will be called every time a dependence
3187   /// changes its state (and in the beginning).
3188   ChangeStatus updateImpl(Attributor &A) override {
3189     KernelInfoState StateBefore = getState();
3190 
3191     // Callback to check a read/write instruction.
3192     auto CheckRWInst = [&](Instruction &I) {
3193       // We handle calls later.
3194       if (isa<CallBase>(I))
3195         return true;
3196       // We only care about write effects.
3197       if (!I.mayWriteToMemory())
3198         return true;
3199       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3200         SmallVector<const Value *> Objects;
3201         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3202         if (llvm::all_of(Objects,
3203                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3204           return true;
3205       }
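      // Stores whose underlying objects are all allocas are treated as harmless
      // here, presumably because stack objects are private to the executing
      // thread; everything else is conservatively recorded as SPMD-incompatible.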
3206       // For now we give up on everything but stores.
3207       SPMDCompatibilityTracker.insert(&I);
3208       return true;
3209     };
3210 
3211     bool UsedAssumedInformationInCheckRWInst = false;
3212     if (!A.checkForAllReadWriteInstructions(
3213             CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3214       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3215 
3216     // Callback to check a call instruction.
3217     auto CheckCallInst = [&](Instruction &I) {
3218       auto &CB = cast<CallBase>(I);
3219       auto &CBAA = A.getAAFor<AAKernelInfo>(
3220           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3221       if (CBAA.getState().isValidState())
3222         getState() ^= CBAA.getState();
3223       return true;
3224     };
3225 
3226     bool UsedAssumedInformationInCheckCallInst = false;
3227     if (!A.checkForAllCallLikeInstructions(
3228             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3229       return indicatePessimisticFixpoint();
3230 
3231     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3232                                      : ChangeStatus::CHANGED;
3233   }
3234 };
3235 
3236 /// The call site kernel info abstract attribute, basically, what can we say
3237 /// about a call site with regards to the KernelInfoState. For now this simply
3238 /// forwards the information from the callee.
3239 struct AAKernelInfoCallSite : AAKernelInfo {
3240   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3241       : AAKernelInfo(IRP, A) {}
3242 
3243   /// See AbstractAttribute::initialize(...).
3244   void initialize(Attributor &A) override {
3245     AAKernelInfo::initialize(A);
3246 
3247     CallBase &CB = cast<CallBase>(getAssociatedValue());
3248     Function *Callee = getAssociatedFunction();
3249 
3250     // Helper to lookup an assumption string.
3251     auto HasAssumption = [](Function *Fn, StringRef AssumptionStr) {
3252       return Fn && hasAssumption(*Fn, AssumptionStr);
3253     };
3254 
3255     // Check for SPMD-mode assumptions.
3256     if (HasAssumption(Callee, "ompx_spmd_amenable"))
3257       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3258 
3259     // First weed out calls we do not care about, that is, readonly/readnone
3260     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3261     // parallel region or anything else we are looking for.
3262     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3263       indicateOptimisticFixpoint();
3264       return;
3265     }
3266 
3267     // Next we check if we know the callee. If it is a known OpenMP function
3268     // we will handle them explicitly in the switch below. If it is not, we
3269     // will use an AAKernelInfo object on the callee to gather information and
3270     // merge that into the current state. The latter happens in the updateImpl.
3271     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3272     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3273     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3274       // Unknown callees or declarations are not analyzable; we give up.
3275       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3276 
3277         // Unknown callees might contain parallel regions, except if they have
3278         // an appropriate assumption attached.
3279         if (!(HasAssumption(Callee, "omp_no_openmp") ||
3280               HasAssumption(Callee, "omp_no_parallelism")))
3281           ReachedUnknownParallelRegions.insert(&CB);
3282 
3283         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3284         // idea we can run something unknown in SPMD-mode.
3285         if (!SPMDCompatibilityTracker.isAtFixpoint())
3286           SPMDCompatibilityTracker.insert(&CB);
3287 
3288         // We have updated the state for this unknown call properly, there won't
3289         // be any change so we indicate a fixpoint.
3290         indicateOptimisticFixpoint();
3291       }
3292       // If the callee is known and can be used in IPO, we will update the state
3293       // based on the callee state in updateImpl.
3294       return;
3295     }
3296 
3297     const unsigned int WrapperFunctionArgNo = 6;
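    // NOTE (assumption): index 6 is taken to be the outlined-region wrapper
    // function operand of a __kmpc_parallel_51 call; it is only used for that
    // runtime function in the switch below.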
3298     RuntimeFunction RF = It->getSecond();
3299     switch (RF) {
3300     // All the functions we know are compatible with SPMD mode.
3301     case OMPRTL___kmpc_is_spmd_exec_mode:
3302     case OMPRTL___kmpc_for_static_fini:
3303     case OMPRTL___kmpc_global_thread_num:
3304     case OMPRTL___kmpc_single:
3305     case OMPRTL___kmpc_end_single:
3306     case OMPRTL___kmpc_master:
3307     case OMPRTL___kmpc_end_master:
3308     case OMPRTL___kmpc_barrier:
3309       break;
3310     case OMPRTL___kmpc_for_static_init_4:
3311     case OMPRTL___kmpc_for_static_init_4u:
3312     case OMPRTL___kmpc_for_static_init_8:
3313     case OMPRTL___kmpc_for_static_init_8u: {
3314       // Check the schedule and allow static schedule in SPMD mode.
3315       unsigned ScheduleArgOpNo = 2;
3316       auto *ScheduleTypeCI =
3317           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3318       unsigned ScheduleTypeVal =
3319           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
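      // If the schedule is not a compile-time constant, ScheduleTypeVal is 0,
      // which falls through to the default case below, i.e., it is
      // conservatively treated as SPMD-incompatible.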
3320       switch (OMPScheduleType(ScheduleTypeVal)) {
3321       case OMPScheduleType::Static:
3322       case OMPScheduleType::StaticChunked:
3323       case OMPScheduleType::Distribute:
3324       case OMPScheduleType::DistributeChunked:
3325         break;
3326       default:
3327         SPMDCompatibilityTracker.insert(&CB);
3328         break;
3329       };
3330     } break;
3331     case OMPRTL___kmpc_target_init:
3332       KernelInitCB = &CB;
3333       break;
3334     case OMPRTL___kmpc_target_deinit:
3335       KernelDeinitCB = &CB;
3336       break;
3337     case OMPRTL___kmpc_parallel_51:
3338       if (auto *ParallelRegion = dyn_cast<Function>(
3339               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3340         ReachedKnownParallelRegions.insert(ParallelRegion);
3341         break;
3342       }
3343       // The condition above should usually get the parallel region function
3344       // pointer and record it. In the off chance it doesn't we assume the
3345       // worst.
3346       ReachedUnknownParallelRegions.insert(&CB);
3347       break;
3348     case OMPRTL___kmpc_omp_task:
3349       // We do not look into tasks right now, just give up.
3350       SPMDCompatibilityTracker.insert(&CB);
3351       ReachedUnknownParallelRegions.insert(&CB);
3352       break;
3353     default:
3354       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3355       // generally.
3356       SPMDCompatibilityTracker.insert(&CB);
3357       break;
3358     }
3359     // All other OpenMP runtime calls will not reach parallel regions so they
3360     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3361     // have now modeled all effects and there is no need for any update.
3362     indicateOptimisticFixpoint();
3363   }
3364 
3365   ChangeStatus updateImpl(Attributor &A) override {
3366     // TODO: Once we have call site specific value information we can provide
3367     //       call site specific liveness information and then it makes
3368     //       sense to specialize attributes for call sites arguments instead of
3369     //       redirecting requests to the callee argument.
3370     Function *F = getAssociatedFunction();
3371     const IRPosition &FnPos = IRPosition::function(*F);
3372     auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3373     if (getState() == FnAA.getState())
3374       return ChangeStatus::UNCHANGED;
3375     getState() = FnAA.getState();
3376     return ChangeStatus::CHANGED;
3377   }
3378 };
3379 
3380 } // namespace
3381 
3382 void OpenMPOpt::registerAAs(bool IsModulePass) {
3383   if (SCC.empty())
3385     return;
3386   if (IsModulePass) {
3387     // Ensure we create the AAKernelInfo AAs first and without triggering an
3388     // update. This will make sure we register all value simplification
3389     // callbacks before any other AA has the chance to create an AAValueSimplify
3390     // or similar.
3391     for (Function *Kernel : OMPInfoCache.Kernels)
3392       A.getOrCreateAAFor<AAKernelInfo>(
3393           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
3394           DepClassTy::NONE, /* ForceUpdate */ false,
3395           /* UpdateAfterInit */ false);
3396   }
3397 
3398   // Create CallSite AA for all Getters.
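  // Note (assumption): the '- 1' presumably skips a trailing sentinel enumerator
  // of InternalControlVar that has no associated getter runtime function.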
3399   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
3400     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
3401 
3402     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
3403 
3404     auto CreateAA = [&](Use &U, Function &Caller) {
3405       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
3406       if (!CI)
3407         return false;
3408 
3409       auto &CB = cast<CallBase>(*CI);
3410 
3411       IRPosition CBPos = IRPosition::callsite_function(CB);
3412       A.getOrCreateAAFor<AAICVTracker>(CBPos);
3413       return false;
3414     };
3415 
3416     GetterRFI.foreachUse(SCC, CreateAA);
3417   }
3418   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3419   auto CreateAA = [&](Use &U, Function &F) {
3420     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
3421     return false;
3422   };
3423   GlobalizationRFI.foreachUse(SCC, CreateAA);
3424 
3425   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
3426   // every function if the module is an OpenMP device module.
3427   for (auto *F : SCC) {
3428     if (!F->isDeclaration())
3429       A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
3430     if (isOpenMPDevice(M))
3431       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
3432   }
3433 }
3434 
3435 const char AAICVTracker::ID = 0;
3436 const char AAKernelInfo::ID = 0;
3437 const char AAExecutionDomain::ID = 0;
3438 const char AAHeapToShared::ID = 0;
3439 
3440 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
3441                                               Attributor &A) {
3442   AAICVTracker *AA = nullptr;
3443   switch (IRP.getPositionKind()) {
3444   case IRPosition::IRP_INVALID:
3445   case IRPosition::IRP_FLOAT:
3446   case IRPosition::IRP_ARGUMENT:
3447   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3448     llvm_unreachable("ICVTracker can only be created for function position!");
3449   case IRPosition::IRP_RETURNED:
3450     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
3451     break;
3452   case IRPosition::IRP_CALL_SITE_RETURNED:
3453     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
3454     break;
3455   case IRPosition::IRP_CALL_SITE:
3456     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
3457     break;
3458   case IRPosition::IRP_FUNCTION:
3459     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
3460     break;
3461   }
3462 
3463   return *AA;
3464 }
3465 
3466 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
3467                                                         Attributor &A) {
3468   AAExecutionDomainFunction *AA = nullptr;
3469   switch (IRP.getPositionKind()) {
3470   case IRPosition::IRP_INVALID:
3471   case IRPosition::IRP_FLOAT:
3472   case IRPosition::IRP_ARGUMENT:
3473   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3474   case IRPosition::IRP_RETURNED:
3475   case IRPosition::IRP_CALL_SITE_RETURNED:
3476   case IRPosition::IRP_CALL_SITE:
3477     llvm_unreachable(
3478         "AAExecutionDomain can only be created for function position!");
3479   case IRPosition::IRP_FUNCTION:
3480     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
3481     break;
3482   }
3483 
3484   return *AA;
3485 }
3486 
3487 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
3488                                                   Attributor &A) {
3489   AAHeapToSharedFunction *AA = nullptr;
3490   switch (IRP.getPositionKind()) {
3491   case IRPosition::IRP_INVALID:
3492   case IRPosition::IRP_FLOAT:
3493   case IRPosition::IRP_ARGUMENT:
3494   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3495   case IRPosition::IRP_RETURNED:
3496   case IRPosition::IRP_CALL_SITE_RETURNED:
3497   case IRPosition::IRP_CALL_SITE:
3498     llvm_unreachable(
3499         "AAHeapToShared can only be created for function position!");
3500   case IRPosition::IRP_FUNCTION:
3501     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
3502     break;
3503   }
3504 
3505   return *AA;
3506 }
3507 
3508 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
3509                                               Attributor &A) {
3510   AAKernelInfo *AA = nullptr;
3511   switch (IRP.getPositionKind()) {
3512   case IRPosition::IRP_INVALID:
3513   case IRPosition::IRP_FLOAT:
3514   case IRPosition::IRP_ARGUMENT:
3515   case IRPosition::IRP_RETURNED:
3516   case IRPosition::IRP_CALL_SITE_RETURNED:
3517   case IRPosition::IRP_CALL_SITE_ARGUMENT:
3518     llvm_unreachable("KernelInfo can only be created for function position!");
3519   case IRPosition::IRP_CALL_SITE:
3520     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
3521     break;
3522   case IRPosition::IRP_FUNCTION:
3523     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
3524     break;
3525   }
3526 
3527   return *AA;
3528 }
3529 
3530 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
3531   if (!containsOpenMP(M))
3532     return PreservedAnalyses::all();
3533   if (DisableOpenMPOptimizations)
3534     return PreservedAnalyses::all();
3535 
3536   FunctionAnalysisManager &FAM =
3537       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
3538   KernelSet Kernels = getDeviceKernels(M);
3539 
3540   auto IsCalled = [&](Function &F) {
3541     if (Kernels.contains(&F))
3542       return true;
3543     for (const User *U : F.users())
3544       if (!isa<BlockAddress>(U))
3545         return true;
3546     return false;
3547   };
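  // Kernels count as called even without IR uses since they are expected to be
  // launched by the offloading runtime; other functions count as called if they
  // have any user that is not a block address.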
3548 
3549   auto EmitRemark = [&](Function &F) {
3550     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
3551     ORE.emit([&]() {
3552       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "InternalizationFailure", &F);
3553       return ORA << "Could not internalize function. "
3554                  << "Some optimizations may not be possible.";
3555     });
3556   };
3557 
3558   // Create internal copies of each function if this is a kernel module. This
3559   // allows interprocedural passes to see every call edge.
3560   DenseSet<const Function *> InternalizedFuncs;
3561   if (isOpenMPDevice(M))
3562     for (Function &F : M)
3563       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F)) {
3564         if (Attributor::internalizeFunction(F, /* Force */ true)) {
3565           InternalizedFuncs.insert(&F);
3566         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
3567           EmitRemark(F);
3568         }
3569       }
3570 
3571   // Look at every function in the Module unless it was internalized.
3572   SmallVector<Function *, 16> SCC;
3573   for (Function &F : M)
3574     if (!F.isDeclaration() && !InternalizedFuncs.contains(&F))
3575       SCC.push_back(&F);
3576 
3577   if (SCC.empty())
3578     return PreservedAnalyses::all();
3579 
3580   AnalysisGetter AG(FAM);
3581 
3582   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
3583     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
3584   };
3585 
3586   BumpPtrAllocator Allocator;
3587   CallGraphUpdater CGUpdater;
3588 
3589   SetVector<Function *> Functions(SCC.begin(), SCC.end());
3590   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
3591 
3592   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
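  // Note (assumption): device modules get a larger iteration budget, presumably
  // because the kernel-related AAs need more rounds to reach a fixpoint.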
3593   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
3594                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
3595 
3596   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
3597   bool Changed = OMPOpt.run(true);
3598   if (Changed)
3599     return PreservedAnalyses::none();
3600 
3601   return PreservedAnalyses::all();
3602 }
3603 
3604 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
3605                                           CGSCCAnalysisManager &AM,
3606                                           LazyCallGraph &CG,
3607                                           CGSCCUpdateResult &UR) {
3608   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
3609     return PreservedAnalyses::all();
3610   if (DisableOpenMPOptimizations)
3611     return PreservedAnalyses::all();
3612 
3613   SmallVector<Function *, 16> SCC;
3614   // If there are kernels in the module, we have to run on all SCCs.
3615   for (LazyCallGraph::Node &N : C) {
3616     Function *Fn = &N.getFunction();
3617     SCC.push_back(Fn);
3618   }
3619 
3620   if (SCC.empty())
3621     return PreservedAnalyses::all();
3622 
3623   Module &M = *C.begin()->getFunction().getParent();
3624 
3625   KernelSet Kernels = getDeviceKernels(M);
3626 
3627   FunctionAnalysisManager &FAM =
3628       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
3629 
3630   AnalysisGetter AG(FAM);
3631 
3632   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
3633     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
3634   };
3635 
3636   BumpPtrAllocator Allocator;
3637   CallGraphUpdater CGUpdater;
3638   CGUpdater.initialize(CG, C, AM, UR);
3639 
3640   SetVector<Function *> Functions(SCC.begin(), SCC.end());
3641   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
3642                                 /*CGSCC*/ Functions, Kernels);
3643 
3644   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
3645   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
3646                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
3647 
3648   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
3649   bool Changed = OMPOpt.run(false);
3650   if (Changed)
3651     return PreservedAnalyses::none();
3652 
3653   return PreservedAnalyses::all();
3654 }
3655 
3656 namespace {
3657 
3658 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
3659   CallGraphUpdater CGUpdater;
3660   static char ID;
3661 
3662   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
3663     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
3664   }
3665 
3666   void getAnalysisUsage(AnalysisUsage &AU) const override {
3667     CallGraphSCCPass::getAnalysisUsage(AU);
3668   }
3669 
3670   bool runOnSCC(CallGraphSCC &CGSCC) override {
3671     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
3672       return false;
3673     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
3674       return false;
3675 
3676     SmallVector<Function *, 16> SCC;
3677     // If there are kernels in the module, we have to run on all SCCs.
3678     for (CallGraphNode *CGN : CGSCC) {
3679       Function *Fn = CGN->getFunction();
3680       if (!Fn || Fn->isDeclaration())
3681         continue;
3682       SCC.push_back(Fn);
3683     }
3684 
3685     if (SCC.empty())
3686       return false;
3687 
3688     Module &M = CGSCC.getCallGraph().getModule();
3689     KernelSet Kernels = getDeviceKernels(M);
3690 
3691     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
3692     CGUpdater.initialize(CG, CGSCC);
3693 
3694     // Maintain a map of functions to avoid rebuilding the ORE for each of them.
3695     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
3696     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
3697       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
3698       if (!ORE)
3699         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
3700       return *ORE;
3701     };
3702 
3703     AnalysisGetter AG;
3704     SetVector<Function *> Functions(SCC.begin(), SCC.end());
3705     BumpPtrAllocator Allocator;
3706     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
3707                                   Allocator,
3708                                   /*CGSCC*/ Functions, Kernels);
3709 
3710     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
3711     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
3712                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
3713 
3714     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
3715     return OMPOpt.run(false);
3716   }
3717 
3718   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
3719 };
3720 
3721 } // end anonymous namespace
3722 
3723 KernelSet llvm::omp::getDeviceKernels(Module &M) {
3724   // TODO: Create a more cross-platform way of determining device kernels.
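  // Each kernel is expected to be announced by a metadata operand of the form
  //   !{ptr @kernel, !"kernel", ...}
  // (exact trailing operands may vary); only the first two operands are
  // inspected below.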
3725   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
3726   KernelSet Kernels;
3727 
3728   if (!MD)
3729     return Kernels;
3730 
3731   for (auto *Op : MD->operands()) {
3732     if (Op->getNumOperands() < 2)
3733       continue;
3734     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
3735     if (!KindID || KindID->getString() != "kernel")
3736       continue;
3737 
3738     Function *KernelFn =
3739         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
3740     if (!KernelFn)
3741       continue;
3742 
3743     ++NumOpenMPTargetRegionKernels;
3744 
3745     Kernels.insert(KernelFn);
3746   }
3747 
3748   return Kernels;
3749 }
3750 
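// Both predicates below only test for the presence of the respective module
// flag ("openmp" and "openmp-device"); the flag value, e.g., the OpenMP version,
// is not inspected. The flags are assumed to be added by the frontend.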
3751 bool llvm::omp::containsOpenMP(Module &M) {
3752   Metadata *MD = M.getModuleFlag("openmp");
3753   if (!MD)
3754     return false;
3755 
3756   return true;
3757 }
3758 
3759 bool llvm::omp::isOpenMPDevice(Module &M) {
3760   Metadata *MD = M.getModuleFlag("openmp-device");
3761   if (!MD)
3762     return false;
3763 
3764   return true;
3765 }
3766 
3767 char OpenMPOptCGSCCLegacyPass::ID = 0;
3768 
3769 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
3770                       "OpenMP specific optimizations", false, false)
3771 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
3772 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
3773                     "OpenMP specific optimizations", false, false)
3774 
3775 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
3776   return new OpenMPOptCGSCCLegacyPass();
3777 }
3778