1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
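// For example, runtime call deduplication turns repeated calls such as
//
//   %tid.0 = call i32 @omp_get_thread_num()
//   ...
//   %tid.1 = call i32 @omp_get_thread_num()
//
// into a single call whose result is reused (an illustrative sketch, not the
// exact IR that is produced).
//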
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/InitializePasses.h"
37 #include "llvm/Support/CommandLine.h"
38 #include "llvm/Transforms/IPO.h"
39 #include "llvm/Transforms/IPO/Attributor.h"
40 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
41 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
42 #include "llvm/Transforms/Utils/CodeExtractor.h"
43 
44 using namespace llvm;
45 using namespace omp;
46 
47 #define DEBUG_TYPE "openmp-opt"
48 
49 static cl::opt<bool> DisableOpenMPOptimizations(
50     "openmp-opt-disable", cl::ZeroOrMore,
51     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
52     cl::init(false));
53 
54 static cl::opt<bool> EnableParallelRegionMerging(
55     "openmp-opt-enable-merging", cl::ZeroOrMore,
56     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
57     cl::init(false));
58 
59 static cl::opt<bool>
60     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
61                            cl::desc("Disable function internalization."),
62                            cl::Hidden, cl::init(false));
63 
64 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
65                                     cl::Hidden);
66 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
67                                         cl::init(false), cl::Hidden);
68 
69 static cl::opt<bool> HideMemoryTransferLatency(
70     "openmp-hide-memory-transfer-latency",
71     cl::desc("[WIP] Tries to hide the latency of host to device memory"
72              " transfers"),
73     cl::Hidden, cl::init(false));
74 
75 static cl::opt<bool> DisableOpenMPOptDeglobalization(
76     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
77     cl::desc("Disable OpenMP optimizations involving deglobalization."),
78     cl::Hidden, cl::init(false));
79 
80 static cl::opt<bool> DisableOpenMPOptSPMDization(
81     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
82     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
83     cl::Hidden, cl::init(false));
84 
85 static cl::opt<bool> DisableOpenMPOptFolding(
86     "openmp-opt-disable-folding", cl::ZeroOrMore,
87     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
88     cl::init(false));
89 
90 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
91     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
92     cl::desc("Disable OpenMP optimizations that replace the state machine."),
93     cl::Hidden, cl::init(false));
94 
95 static cl::opt<bool> PrintModuleAfterOptimizations(
96     "openmp-opt-print-module", cl::ZeroOrMore,
97     cl::desc("Print the current module after OpenMP optimizations."),
98     cl::Hidden, cl::init(false));
99 
100 static cl::opt<bool> AlwaysInlineDeviceFunctions(
101     "openmp-opt-inline-device", cl::ZeroOrMore,
102     cl::desc("Inline all applicable functions on the device."), cl::Hidden,
103     cl::init(false));
104 
105 static cl::opt<bool>
106     EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
107                          cl::desc("Enables more verbose remarks."), cl::Hidden,
108                          cl::init(false));
109 
110 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
111           "Number of OpenMP runtime calls deduplicated");
112 STATISTIC(NumOpenMPParallelRegionsDeleted,
113           "Number of OpenMP parallel regions deleted");
114 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
115           "Number of OpenMP runtime functions identified");
116 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
117           "Number of OpenMP runtime function uses identified");
118 STATISTIC(NumOpenMPTargetRegionKernels,
119           "Number of OpenMP target region entry points (=kernels) identified");
120 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
121           "Number of OpenMP target region entry points (=kernels) executed in "
122           "SPMD-mode instead of generic-mode");
123 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
124           "Number of OpenMP target region entry points (=kernels) executed in "
125           "generic-mode without a state machine");
126 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
127           "Number of OpenMP target region entry points (=kernels) executed in "
128           "generic-mode with customized state machines with fallback");
129 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
130           "Number of OpenMP target region entry points (=kernels) executed in "
131           "generic-mode with customized state machines without fallback");
132 STATISTIC(
133     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
134     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
135 STATISTIC(NumOpenMPParallelRegionsMerged,
136           "Number of OpenMP parallel regions merged");
137 STATISTIC(NumBytesMovedToSharedMemory,
138           "Number of bytes moved to shared memory");
139 
140 #if !defined(NDEBUG)
141 static constexpr auto TAG = "[" DEBUG_TYPE "]";
142 #endif
143 
144 namespace {
145 
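/// Address spaces used by device code; the numbering matches the convention
/// of the NVPTX and AMDGPU backends.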
146 enum class AddressSpace : unsigned {
147   Generic = 0,
148   Global = 1,
149   Shared = 3,
150   Constant = 4,
151   Local = 5,
152 };
153 
154 struct AAHeapToShared;
155 
156 struct AAICVTracker;
157 
158 /// OpenMP specific information. For now, it stores the RFIs and ICVs that
159 /// are also needed for Attributor runs.
160 struct OMPInformationCache : public InformationCache {
161   OMPInformationCache(Module &M, AnalysisGetter &AG,
162                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
163                       SmallPtrSetImpl<Kernel> &Kernels)
164       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
165         Kernels(Kernels) {
166 
167     OMPBuilder.initialize();
168     initializeRuntimeFunctions();
169     initializeInternalControlVars();
170   }
171 
172   /// Generic information that describes an internal control variable.
173   struct InternalControlVarInfo {
174     /// The kind, as described by InternalControlVar enum.
175     InternalControlVar Kind;
176 
177     /// The name of the ICV.
178     StringRef Name;
179 
180     /// Environment variable associated with this ICV.
181     StringRef EnvVarName;
182 
183     /// Initial value kind.
184     ICVInitValue InitKind;
185 
186     /// Initial value.
187     ConstantInt *InitValue;
188 
189     /// Setter RTL function associated with this ICV.
190     RuntimeFunction Setter;
191 
192     /// Getter RTL function associated with this ICV.
193     RuntimeFunction Getter;
194 
195     /// RTL function corresponding to the override clause of this ICV.
196     RuntimeFunction Clause;
197   };
198 
199   /// Generic information that describes a runtime function.
200   struct RuntimeFunctionInfo {
201 
202     /// The kind, as described by the RuntimeFunction enum.
203     RuntimeFunction Kind;
204 
205     /// The name of the function.
206     StringRef Name;
207 
208     /// Flag to indicate a variadic function.
209     bool IsVarArg;
210 
211     /// The return type of the function.
212     Type *ReturnType;
213 
214     /// The argument types of the function.
215     SmallVector<Type *, 8> ArgumentTypes;
216 
217     /// The declaration if available.
218     Function *Declaration = nullptr;
219 
220     /// Uses of this runtime function per function containing the use.
221     using UseVector = SmallVector<Use *, 16>;
222 
223     /// Clear UsesMap for runtime function.
224     void clearUsesMap() { UsesMap.clear(); }
225 
226     /// Boolean conversion that is true if the runtime function was found.
227     operator bool() const { return Declaration; }
228 
229     /// Return the vector of uses in function \p F.
230     UseVector &getOrCreateUseVector(Function *F) {
231       std::shared_ptr<UseVector> &UV = UsesMap[F];
232       if (!UV)
233         UV = std::make_shared<UseVector>();
234       return *UV;
235     }
236 
237     /// Return the vector of uses in function \p F or `nullptr` if there are
238     /// none.
239     const UseVector *getUseVector(Function &F) const {
240       auto I = UsesMap.find(&F);
241       if (I != UsesMap.end())
242         return I->second.get();
243       return nullptr;
244     }
245 
246     /// Return how many functions contain uses of this runtime function.
247     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
248 
249     /// Return the number of arguments (or the minimal number for variadic
250     /// functions).
251     size_t getNumArgs() const { return ArgumentTypes.size(); }
252 
253     /// Run the callback \p CB on each use and forget the use if the result is
254     /// true. The callback is passed the function in which the use was
255     /// encountered as its second argument.
256     void foreachUse(SmallVectorImpl<Function *> &SCC,
257                     function_ref<bool(Use &, Function &)> CB) {
258       for (Function *F : SCC)
259         foreachUse(CB, F);
260     }
261 
262     /// Run the callback \p CB on each use within the function \p F and forget
263     /// the use if the result is true.
264     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
265       SmallVector<unsigned, 8> ToBeDeleted;
266       ToBeDeleted.clear();
267 
268       unsigned Idx = 0;
269       UseVector &UV = getOrCreateUseVector(F);
270 
271       for (Use *U : UV) {
272         if (CB(*U, *F))
273           ToBeDeleted.push_back(Idx);
274         ++Idx;
275       }
276 
277       // Remove the to-be-deleted indices in reverse order, replacing each
278       // element with the last one, so that smaller indices stay valid.
279       while (!ToBeDeleted.empty()) {
280         unsigned Idx = ToBeDeleted.pop_back_val();
281         UV[Idx] = UV.back();
282         UV.pop_back();
283       }
284     }
285 
286   private:
287     /// Map from functions to all uses of this runtime function contained in
288     /// them.
289     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
290 
291   public:
292     /// Iterators for the uses of this runtime function.
293     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
294     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
295   };
296 
297   /// An OpenMP-IR-Builder instance
298   OpenMPIRBuilder OMPBuilder;
299 
300   /// Map from runtime function kind to the runtime function description.
301   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
302                   RuntimeFunction::OMPRTL___last>
303       RFIs;
304 
305   /// Map from function declarations/definitions to their runtime enum type.
306   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
307 
308   /// Map from ICV kind to the ICV description.
309   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
310                   InternalControlVar::ICV___last>
311       ICVs;
312 
313   /// Helper to initialize all internal control variable information for those
314   /// defined in OMPKinds.def.
315   void initializeInternalControlVars() {
316 #define ICV_RT_SET(_Name, RTL)                                                 \
317   {                                                                            \
318     auto &ICV = ICVs[_Name];                                                   \
319     ICV.Setter = RTL;                                                          \
320   }
321 #define ICV_RT_GET(Name, RTL)                                                  \
322   {                                                                            \
323     auto &ICV = ICVs[Name];                                                    \
324     ICV.Getter = RTL;                                                          \
325   }
326 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
327   {                                                                            \
328     auto &ICV = ICVs[Enum];                                                    \
329     ICV.Name = _Name;                                                          \
330     ICV.Kind = Enum;                                                           \
331     ICV.InitKind = Init;                                                       \
332     ICV.EnvVarName = _EnvVarName;                                              \
333     switch (ICV.InitKind) {                                                    \
334     case ICV_IMPLEMENTATION_DEFINED:                                           \
335       ICV.InitValue = nullptr;                                                 \
336       break;                                                                   \
337     case ICV_ZERO:                                                             \
338       ICV.InitValue = ConstantInt::get(                                        \
339           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
340       break;                                                                   \
341     case ICV_FALSE:                                                            \
342       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
343       break;                                                                   \
344     case ICV_LAST:                                                             \
345       break;                                                                   \
346     }                                                                          \
347   }
348 #include "llvm/Frontend/OpenMP/OMPKinds.def"
349   }
350 
351   /// Returns true if the function declaration \p F matches the runtime
352   /// function types, that is, return type \p RTFRetType and argument types
353   /// \p RTFArgTypes.
354   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
355                                   SmallVector<Type *, 8> &RTFArgTypes) {
356     // TODO: We should output information to the user (under debug output
357     //       and via remarks).
358 
359     if (!F)
360       return false;
361     if (F->getReturnType() != RTFRetType)
362       return false;
363     if (F->arg_size() != RTFArgTypes.size())
364       return false;
365 
366     auto RTFTyIt = RTFArgTypes.begin();
367     for (Argument &Arg : F->args()) {
368       if (Arg.getType() != *RTFTyIt)
369         return false;
370 
371       ++RTFTyIt;
372     }
373 
374     return true;
375   }
376 
377   // Helper to collect all uses of the declaration in the UsesMap.
378   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
379     unsigned NumUses = 0;
380     if (!RFI.Declaration)
381       return NumUses;
382     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
383 
384     if (CollectStats) {
385       NumOpenMPRuntimeFunctionsIdentified += 1;
386       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
387     }
388 
389     // TODO: We directly convert uses into proper calls and unknown uses.
390     for (Use &U : RFI.Declaration->uses()) {
391       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
392         if (ModuleSlice.count(UserI->getFunction())) {
393           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
394           ++NumUses;
395         }
396       } else {
397         RFI.getOrCreateUseVector(nullptr).push_back(&U);
398         ++NumUses;
399       }
400     }
401     return NumUses;
402   }
403 
404   // Helper function to recollect uses of a runtime function.
405   void recollectUsesForFunction(RuntimeFunction RTF) {
406     auto &RFI = RFIs[RTF];
407     RFI.clearUsesMap();
408     collectUses(RFI, /*CollectStats*/ false);
409   }
410 
411   // Helper function to recollect uses of all runtime functions.
412   void recollectUses() {
413     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
414       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
415   }
416 
417   /// Helper to initialize all runtime function information for those defined
418   /// in OMPKinds.def.
419   void initializeRuntimeFunctions() {
420     Module &M = *((*ModuleSlice.begin())->getParent());
421 
422     // Helper macros for handling __VA_ARGS__ in OMP_RTL
423 #define OMP_TYPE(VarName, ...)                                                 \
424   Type *VarName = OMPBuilder.VarName;                                          \
425   (void)VarName;
426 
427 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
428   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
429   (void)VarName##Ty;                                                           \
430   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
431   (void)VarName##PtrTy;
432 
433 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
434   FunctionType *VarName = OMPBuilder.VarName;                                  \
435   (void)VarName;                                                               \
436   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
437   (void)VarName##Ptr;
438 
439 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
440   StructType *VarName = OMPBuilder.VarName;                                    \
441   (void)VarName;                                                               \
442   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
443   (void)VarName##Ptr;
444 
445 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
446   {                                                                            \
447     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
448     Function *F = M.getFunction(_Name);                                        \
449     RTLFunctions.insert(F);                                                    \
450     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
451       RuntimeFunctionIDMap[F] = _Enum;                                         \
452       F->removeFnAttr(Attribute::NoInline);                                    \
453       auto &RFI = RFIs[_Enum];                                                 \
454       RFI.Kind = _Enum;                                                        \
455       RFI.Name = _Name;                                                        \
456       RFI.IsVarArg = _IsVarArg;                                                \
457       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
458       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
459       RFI.Declaration = F;                                                     \
460       unsigned NumUses = collectUses(RFI);                                     \
461       (void)NumUses;                                                           \
462       LLVM_DEBUG({                                                             \
463         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
464                << " found\n";                                                  \
465         if (RFI.Declaration)                                                   \
466           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
467                  << RFI.getNumFunctionsWithUses()                              \
468                  << " different functions.\n";                                 \
469       });                                                                      \
470     }                                                                          \
471   }
472 #include "llvm/Frontend/OpenMP/OMPKinds.def"
473 
474     // TODO: We should attach the attributes defined in OMPKinds.def.
475   }
476 
477   /// Collection of known kernels (\see Kernel) in the module.
478   SmallPtrSetImpl<Kernel> &Kernels;
479 
480   /// Collection of known OpenMP runtime functions.
481   DenseSet<const Function *> RTLFunctions;
482 };
483 
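/// Helper state that combines a BooleanState with a SetVector of elements. If
/// the InsertInvalidates template parameter is true, inserting a new element
/// pessimizes the boolean state (see insert()).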
484 template <typename Ty, bool InsertInvalidates = true>
485 struct BooleanStateWithSetVector : public BooleanState {
486   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
487   bool insert(const Ty &Elem) {
488     if (InsertInvalidates)
489       BooleanState::indicatePessimisticFixpoint();
490     return Set.insert(Elem);
491   }
492 
493   const Ty &operator[](int Idx) const { return Set[Idx]; }
494   bool operator==(const BooleanStateWithSetVector &RHS) const {
495     return BooleanState::operator==(RHS) && Set == RHS.Set;
496   }
497   bool operator!=(const BooleanStateWithSetVector &RHS) const {
498     return !(*this == RHS);
499   }
500 
501   bool empty() const { return Set.empty(); }
502   size_t size() const { return Set.size(); }
503 
504   /// "Clamp" this state with \p RHS.
505   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
506     BooleanState::operator^=(RHS);
507     Set.insert(RHS.Set.begin(), RHS.Set.end());
508     return *this;
509   }
510 
511 private:
512   /// A set to keep track of elements.
513   SetVector<Ty> Set;
514 
515 public:
516   typename decltype(Set)::iterator begin() { return Set.begin(); }
517   typename decltype(Set)::iterator end() { return Set.end(); }
518   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
519   typename decltype(Set)::const_iterator end() const { return Set.end(); }
520 };
521 
522 template <typename Ty, bool InsertInvalidates = true>
523 using BooleanStateWithPtrSetVector =
524     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
525 
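/// State tracking kernel-related properties for the Attributor: the parallel
/// regions that may be reached, SPMD compatibility, the
/// __kmpc_target_init/_deinit calls, and the kernel entries that reach the
/// associated function.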
526 struct KernelInfoState : AbstractState {
527   /// Flag to track if we reached a fixpoint.
528   bool IsAtFixpoint = false;
529 
530   /// The parallel regions (identified by the outlined parallel functions) that
531   /// can be reached from the associated function.
532   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
533       ReachedKnownParallelRegions;
534 
535   /// State to track what parallel regions we might reach.
536   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
537 
538   /// State to track if we are in SPMD-mode, assumed or known, and why we
539   /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
540   /// also be false.
541   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
542 
543   /// The __kmpc_target_init call in this kernel, if any. If we find more than
544   /// one we abort as the kernel is malformed.
545   CallBase *KernelInitCB = nullptr;
546 
547   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
548   /// one we abort as the kernel is malformed.
549   CallBase *KernelDeinitCB = nullptr;
550 
551   /// Flag to indicate if the associated function is a kernel entry.
552   bool IsKernelEntry = false;
553 
554   /// State to track what kernel entries can reach the associated function.
555   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
556 
557   /// State to indicate if we can track the parallel level of the associated
558   /// function. We will give up tracking if we encounter an unknown caller or
559   /// if the caller is __kmpc_parallel_51.
560   BooleanStateWithSetVector<uint8_t> ParallelLevels;
561 
562   /// Abstract State interface
563   ///{
564 
565   KernelInfoState() {}
566   KernelInfoState(bool BestState) {
567     if (!BestState)
568       indicatePessimisticFixpoint();
569   }
570 
571   /// See AbstractState::isValidState(...)
572   bool isValidState() const override { return true; }
573 
574   /// See AbstractState::isAtFixpoint(...)
575   bool isAtFixpoint() const override { return IsAtFixpoint; }
576 
577   /// See AbstractState::indicatePessimisticFixpoint(...)
578   ChangeStatus indicatePessimisticFixpoint() override {
579     IsAtFixpoint = true;
580     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
581     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
582     return ChangeStatus::CHANGED;
583   }
584 
585   /// See AbstractState::indicateOptimisticFixpoint(...)
586   ChangeStatus indicateOptimisticFixpoint() override {
587     IsAtFixpoint = true;
588     return ChangeStatus::UNCHANGED;
589   }
590 
591   /// Return the assumed state
592   KernelInfoState &getAssumed() { return *this; }
593   const KernelInfoState &getAssumed() const { return *this; }
594 
595   bool operator==(const KernelInfoState &RHS) const {
596     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
597       return false;
598     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
599       return false;
600     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
601       return false;
602     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
603       return false;
604     return true;
605   }
606 
607   /// Returns true if this kernel may contain any OpenMP parallel regions.
608   bool mayContainParallelRegion() {
609     return !ReachedKnownParallelRegions.empty() ||
610            !ReachedUnknownParallelRegions.empty();
611   }
612 
613   /// Return the best possible state (optimistic assumptions, empty sets).
614   static KernelInfoState getBestState() { return KernelInfoState(true); }
615 
616   static KernelInfoState getBestState(KernelInfoState &KIS) {
617     return getBestState();
618   }
619 
620   /// Return the worst possible state (a pessimistic fixpoint).
621   static KernelInfoState getWorstState() { return KernelInfoState(false); }
622 
623   /// "Clamp" this state with \p KIS.
624   KernelInfoState operator^=(const KernelInfoState &KIS) {
625     // Do not merge two different _init and _deinit call sites.
626     if (KIS.KernelInitCB) {
627       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
628         indicatePessimisticFixpoint();
629       KernelInitCB = KIS.KernelInitCB;
630     }
631     if (KIS.KernelDeinitCB) {
632       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
633         indicatePessimisticFixpoint();
634       KernelDeinitCB = KIS.KernelDeinitCB;
635     }
636     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
637     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
638     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
639     return *this;
640   }
641 
642   KernelInfoState operator&=(const KernelInfoState &KIS) {
643     return (*this ^= KIS);
644   }
645 
646   ///}
647 };
648 
649 /// Used to map the values physically stored in an offload array (in the IR)
650 /// to a vector in memory.
651 struct OffloadArray {
652   /// Physical array (in the IR).
653   AllocaInst *Array = nullptr;
654   /// Mapped values.
655   SmallVector<Value *, 8> StoredValues;
656   /// Last stores made in the offload array.
657   SmallVector<StoreInst *, 8> LastAccesses;
658 
659   OffloadArray() = default;
660 
661   /// Initializes the OffloadArray with the values stored in \p Array before
662   /// instruction \p Before is reached. Returns false if the initialization
663   /// fails.
664   /// This MUST be used immediately after the construction of the object.
665   bool initialize(AllocaInst &Array, Instruction &Before) {
666     if (!Array.getAllocatedType()->isArrayTy())
667       return false;
668 
669     if (!getValues(Array, Before))
670       return false;
671 
672     this->Array = &Array;
673     return true;
674   }
675 
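  // Argument positions of the device ID and the offload arrays in a
  // __tgt_target_data_begin_mapper call (see getValuesInOffloadArrays below).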
676   static const unsigned DeviceIDArgNum = 1;
677   static const unsigned BasePtrsArgNum = 3;
678   static const unsigned PtrsArgNum = 4;
679   static const unsigned SizesArgNum = 5;
680 
681 private:
682   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
683   /// \p Array, leaving StoredValues with the values stored before the
684   /// instruction \p Before is reached.
685   bool getValues(AllocaInst &Array, Instruction &Before) {
686     // Initialize container.
687     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
688     StoredValues.assign(NumValues, nullptr);
689     LastAccesses.assign(NumValues, nullptr);
690 
691     // TODO: This assumes the instruction \p Before is in the same
692     //  BasicBlock as Array. Make it general, for any control flow graph.
693     BasicBlock *BB = Array.getParent();
694     if (BB != Before.getParent())
695       return false;
696 
697     const DataLayout &DL = Array.getModule()->getDataLayout();
698     const unsigned int PointerSize = DL.getPointerSize();
699 
700     for (Instruction &I : *BB) {
701       if (&I == &Before)
702         break;
703 
704       if (!isa<StoreInst>(&I))
705         continue;
706 
707       auto *S = cast<StoreInst>(&I);
708       int64_t Offset = -1;
709       auto *Dst =
710           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
711       if (Dst == &Array) {
712         int64_t Idx = Offset / PointerSize;
713         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
714         LastAccesses[Idx] = S;
715       }
716     }
717 
718     return isFilled();
719   }
720 
721   /// Returns true if all entries in StoredValues and
722   /// LastAccesses are non-null.
723   bool isFilled() {
724     const unsigned NumValues = StoredValues.size();
725     for (unsigned I = 0; I < NumValues; ++I) {
726       if (!StoredValues[I] || !LastAccesses[I])
727         return false;
728     }
729 
730     return true;
731   }
732 };
733 
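/// Central struct implementing the OpenMP specific optimizations. The run()
/// method dispatches between the module-pass and CGSCC-pass variants of the
/// transformations.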
734 struct OpenMPOpt {
735 
736   using OptimizationRemarkGetter =
737       function_ref<OptimizationRemarkEmitter &(Function *)>;
738 
739   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
740             OptimizationRemarkGetter OREGetter,
741             OMPInformationCache &OMPInfoCache, Attributor &A)
742       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
743         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
744 
745   /// Check if any remarks are enabled for openmp-opt.
746   bool remarksEnabled() {
747     auto &Ctx = M.getContext();
748     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
749   }
750 
751   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
752   bool run(bool IsModulePass) {
753     if (SCC.empty())
754       return false;
755 
756     bool Changed = false;
757 
758     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
759                       << " functions in a slice with "
760                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
761 
762     if (IsModulePass) {
763       Changed |= runAttributor(IsModulePass);
764 
765       // Recollect uses, in case Attributor deleted any.
766       OMPInfoCache.recollectUses();
767 
768       // TODO: This should be folded into buildCustomStateMachine.
769       Changed |= rewriteDeviceCodeStateMachine();
770 
771       if (remarksEnabled())
772         analysisGlobalization();
773     } else {
774       if (PrintICVValues)
775         printICVs();
776       if (PrintOpenMPKernels)
777         printKernels();
778 
779       Changed |= runAttributor(IsModulePass);
780 
781       // Recollect uses, in case Attributor deleted any.
782       OMPInfoCache.recollectUses();
783 
784       Changed |= deleteParallelRegions();
785 
786       if (HideMemoryTransferLatency)
787         Changed |= hideMemTransfersLatency();
788       Changed |= deduplicateRuntimeCalls();
789       if (EnableParallelRegionMerging) {
790         if (mergeParallelRegions()) {
791           deduplicateRuntimeCalls();
792           Changed = true;
793         }
794       }
795     }
796 
797     return Changed;
798   }
799 
800   /// Print initial ICV values for testing.
801   /// FIXME: This should be done from the Attributor once it is added.
802   void printICVs() const {
803     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
804                                  ICV_proc_bind};
805 
806     for (Function *F : OMPInfoCache.ModuleSlice) {
807       for (auto ICV : ICVs) {
808         auto ICVInfo = OMPInfoCache.ICVs[ICV];
809         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
810           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
811                      << " Value: "
812                      << (ICVInfo.InitValue
813                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
814                              : "IMPLEMENTATION_DEFINED");
815         };
816 
817         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
818       }
819     }
820   }
821 
822   /// Print OpenMP GPU kernels for testing.
823   void printKernels() const {
824     for (Function *F : SCC) {
825       if (!OMPInfoCache.Kernels.count(F))
826         continue;
827 
828       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
829         return ORA << "OpenMP GPU kernel "
830                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
831       };
832 
833       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
834     }
835   }
836 
837   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
838   /// given, it has to be the callee or a nullptr is returned.
839   static CallInst *getCallIfRegularCall(
840       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
841     CallInst *CI = dyn_cast<CallInst>(U.getUser());
842     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
843         (!RFI ||
844          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
845       return CI;
846     return nullptr;
847   }
848 
849   /// Return the call if \p V is a regular call. If \p RFI is given, it has to
850   /// be the callee or a nullptr is returned.
851   static CallInst *getCallIfRegularCall(
852       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
853     CallInst *CI = dyn_cast<CallInst>(&V);
854     if (CI && !CI->hasOperandBundles() &&
855         (!RFI ||
856          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
857       return CI;
858     return nullptr;
859   }
860 
861 private:
862   /// Merge parallel regions when it is safe.
863   bool mergeParallelRegions() {
864     const unsigned CallbackCalleeOperand = 2;
865     const unsigned CallbackFirstArgOperand = 3;
866     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
867 
868     // Check if there are any __kmpc_fork_call calls to merge.
869     OMPInformationCache::RuntimeFunctionInfo &RFI =
870         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
871 
872     if (!RFI.Declaration)
873       return false;
874 
875     // Unmergable calls that prevent merging a parallel region.
876     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
877         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
878         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
879     };
880 
881     bool Changed = false;
882     LoopInfo *LI = nullptr;
883     DominatorTree *DT = nullptr;
884 
885     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
886 
887     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
888     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
889                          BasicBlock &ContinuationIP) {
890       BasicBlock *CGStartBB = CodeGenIP.getBlock();
891       BasicBlock *CGEndBB =
892           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
893       assert(StartBB != nullptr && "StartBB should not be null");
894       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
895       assert(EndBB != nullptr && "EndBB should not be null");
896       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
897     };
898 
899     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
900                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
901       ReplacementValue = &Inner;
902       return CodeGenIP;
903     };
904 
905     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
906 
907     /// Create a sequential execution region within a merged parallel region,
908     /// encapsulated in a master construct with a barrier for synchronization.
909     auto CreateSequentialRegion = [&](Function *OuterFn,
910                                       BasicBlock *OuterPredBB,
911                                       Instruction *SeqStartI,
912                                       Instruction *SeqEndI) {
913       // Isolate the instructions of the sequential region to a separate
914       // block.
915       BasicBlock *ParentBB = SeqStartI->getParent();
916       BasicBlock *SeqEndBB =
917           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
918       BasicBlock *SeqAfterBB =
919           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
920       BasicBlock *SeqStartBB =
921           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
922 
923       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
924              "Expected a different CFG");
925       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
926       ParentBB->getTerminator()->eraseFromParent();
927 
928       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
929                            BasicBlock &ContinuationIP) {
930         BasicBlock *CGStartBB = CodeGenIP.getBlock();
931         BasicBlock *CGEndBB =
932             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
933         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
934         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
935         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
936         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
937       };
938       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
939 
940       // Find outputs from the sequential region to outside users and
941       // broadcast their values to them.
942       for (Instruction &I : *SeqStartBB) {
943         SmallPtrSet<Instruction *, 4> OutsideUsers;
944         for (User *Usr : I.users()) {
945           Instruction &UsrI = *cast<Instruction>(Usr);
946           // Ignore outputs to lifetime intrinsics; code extraction for the
947           // merged parallel region will fix them.
948           if (UsrI.isLifetimeStartOrEnd())
949             continue;
950 
951           if (UsrI.getParent() != SeqStartBB)
952             OutsideUsers.insert(&UsrI);
953         }
954 
955         if (OutsideUsers.empty())
956           continue;
957 
958         // Emit an alloca in the outer region to store the broadcasted
959         // value.
960         const DataLayout &DL = M.getDataLayout();
961         AllocaInst *AllocaI = new AllocaInst(
962             I.getType(), DL.getAllocaAddrSpace(), nullptr,
963             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
964 
965         // Emit a store instruction in the sequential BB to update the
966         // value.
967         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
968 
969         // Emit a load instruction and replace the use of the output value
970         // with it.
971         for (Instruction *UsrI : OutsideUsers) {
972           LoadInst *LoadI = new LoadInst(
973               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
974           UsrI->replaceUsesOfWith(&I, LoadI);
975         }
976       }
977 
978       OpenMPIRBuilder::LocationDescription Loc(
979           InsertPointTy(ParentBB, ParentBB->end()), DL);
980       InsertPointTy SeqAfterIP =
981           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
982 
983       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
984 
985       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
986 
987       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
988                         << "\n");
989     };
990 
991     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
992     // contained in BB and only separated by instructions that can be
993     // redundantly executed in parallel. The block BB is split before the first
994     // call (in MergableCIs) and after the last so the entire region we merge
995     // into a single parallel region is contained in a single basic block
996     // without any other instructions. We use the OpenMPIRBuilder to outline
997     // that block and call the resulting function via __kmpc_fork_call.
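    //
    // For example (an illustrative sketch, not the exact IR), two adjacent
    // forks
    //   call void @__kmpc_fork_call(%ident, 1, @cb0, %arg0)
    //   call void @__kmpc_fork_call(%ident, 1, @cb1, %arg1)
    // become a single fork of a newly outlined function that calls @cb0 and
    // @cb1 directly, with an explicit barrier emitted in between.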
998     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
999       // TODO: Change the interface to allow single CIs to be expanded, e.g.,
1000       // to include an outer loop.
1001       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1002 
1003       auto Remark = [&](OptimizationRemark OR) {
1004         OR << "Parallel region merged with parallel region"
1005            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1006         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1007           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1008           if (CI != MergableCIs.back())
1009             OR << ", ";
1010         }
1011         return OR << ".";
1012       };
1013 
1014       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1015 
1016       Function *OriginalFn = BB->getParent();
1017       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1018                         << " parallel regions in " << OriginalFn->getName()
1019                         << "\n");
1020 
1021       // Isolate the calls to merge in a separate block.
1022       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1023       BasicBlock *AfterBB =
1024           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1025       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1026                            "omp.par.merged");
1027 
1028       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1029       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1030       BB->getTerminator()->eraseFromParent();
1031 
1032       // Create sequential regions for sequential instructions that are
1033       // in-between mergable parallel regions.
1034       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1035            It != End; ++It) {
1036         Instruction *ForkCI = *It;
1037         Instruction *NextForkCI = *(It + 1);
1038 
1039         // Continue if there are no in-between instructions.
1040         if (ForkCI->getNextNode() == NextForkCI)
1041           continue;
1042 
1043         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1044                                NextForkCI->getPrevNode());
1045       }
1046 
1047       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1048                                                DL);
1049       IRBuilder<>::InsertPoint AllocaIP(
1050           &OriginalFn->getEntryBlock(),
1051           OriginalFn->getEntryBlock().getFirstInsertionPt());
1052       // Create the merged parallel region with default proc binding, to
1053       // avoid overriding binding settings, and without explicit cancellation.
1054       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1055           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1056           OMP_PROC_BIND_default, /* IsCancellable */ false);
1057       BranchInst::Create(AfterBB, AfterIP.getBlock());
1058 
1059       // Perform the actual outlining.
1060       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1061                                        /* AllowExtractorSinking */ true);
1062 
1063       Function *OutlinedFn = MergableCIs.front()->getCaller();
1064 
1065       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1066       // callbacks.
1067       SmallVector<Value *, 8> Args;
1068       for (auto *CI : MergableCIs) {
1069         Value *Callee =
1070             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1071         FunctionType *FT =
1072             cast<FunctionType>(Callee->getType()->getPointerElementType());
1073         Args.clear();
1074         Args.push_back(OutlinedFn->getArg(0));
1075         Args.push_back(OutlinedFn->getArg(1));
1076         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1077              U < E; ++U)
1078           Args.push_back(CI->getArgOperand(U));
1079 
1080         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1081         if (CI->getDebugLoc())
1082           NewCI->setDebugLoc(CI->getDebugLoc());
1083 
1084         // Forward parameter attributes from the callback to the callee.
1085         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1086              U < E; ++U)
1087           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1088             NewCI->addParamAttr(
1089                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1090 
1091         // Emit an explicit barrier to replace the implicit fork-join barrier.
1092         if (CI != MergableCIs.back()) {
1093           // TODO: Remove barrier if the merged parallel region includes the
1094           // 'nowait' clause.
1095           OMPInfoCache.OMPBuilder.createBarrier(
1096               InsertPointTy(NewCI->getParent(),
1097                             NewCI->getNextNode()->getIterator()),
1098               OMPD_parallel);
1099         }
1100 
1101         CI->eraseFromParent();
1102       }
1103 
1104       assert(OutlinedFn != OriginalFn && "Outlining failed");
1105       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1106       CGUpdater.reanalyzeFunction(*OriginalFn);
1107 
1108       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1109 
1110       return true;
1111     };
1112 
1113     // Helper function that identifies sequences of
1114     // __kmpc_fork_call uses in a basic block.
1115     auto DetectPRsCB = [&](Use &U, Function &F) {
1116       CallInst *CI = getCallIfRegularCall(U, &RFI);
1117       BB2PRMap[CI->getParent()].insert(CI);
1118 
1119       return false;
1120     };
1121 
1122     BB2PRMap.clear();
1123     RFI.foreachUse(SCC, DetectPRsCB);
1124     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1125     // Find mergable parallel regions within a basic block that are
1126     // safe to merge, that is, any in-between instructions can safely
1127     // execute in parallel after merging.
1128     // TODO: support merging across basic-blocks.
1129     for (auto &It : BB2PRMap) {
1130       auto &CIs = It.getSecond();
1131       if (CIs.size() < 2)
1132         continue;
1133 
1134       BasicBlock *BB = It.getFirst();
1135       SmallVector<CallInst *, 4> MergableCIs;
1136 
1137       /// Returns true if the instruction is mergable, false otherwise.
1138       /// A terminator instruction is unmergable by definition since merging
1139       /// works within a BB. Instructions before the mergable region are
1140       /// mergable if they are not calls to OpenMP runtime functions that may
1141       /// set different execution parameters for subsequent parallel regions.
1142       /// Instructions in-between parallel regions are mergable if they are not
1143       /// calls to any non-intrinsic function since that may call a non-mergable
1144       /// OpenMP runtime function.
1145       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1146         // We do not merge across BBs, hence return false (unmergable) if the
1147         // instruction is a terminator.
1148         if (I.isTerminator())
1149           return false;
1150 
1151         if (!isa<CallInst>(&I))
1152           return true;
1153 
1154         CallInst *CI = cast<CallInst>(&I);
1155         if (IsBeforeMergableRegion) {
1156           Function *CalledFunction = CI->getCalledFunction();
1157           if (!CalledFunction)
1158             return false;
1159           // Return false (unmergable) if the call before the parallel
1160           // region calls an explicit affinity (proc_bind) or number of
1161           // threads (num_threads) compiler-generated function. Those settings
1162           // may be incompatible with following parallel regions.
1163           // TODO: ICV tracking to detect compatibility.
1164           for (const auto &RFI : UnmergableCallsInfo) {
1165             if (CalledFunction == RFI.Declaration)
1166               return false;
1167           }
1168         } else {
1169           // Return false (unmergable) if there is a call instruction
1170           // in-between parallel regions when it is not an intrinsic. It
1171           // may call an unmergable OpenMP runtime function in its callpath.
1172           // TODO: Keep track of possible OpenMP calls in the callpath.
1173           if (!isa<IntrinsicInst>(CI))
1174             return false;
1175         }
1176 
1177         return true;
1178       };
1179       // Find maximal number of parallel region CIs that are safe to merge.
1180       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1181         Instruction &I = *It;
1182         ++It;
1183 
1184         if (CIs.count(&I)) {
1185           MergableCIs.push_back(cast<CallInst>(&I));
1186           continue;
1187         }
1188 
1189         // Continue expanding if the instruction is mergable.
1190         if (IsMergable(I, MergableCIs.empty()))
1191           continue;
1192 
1193         // Forward the instruction iterator to skip the next parallel region
1194         // since there is an unmergable instruction which can affect it.
1195         for (; It != End; ++It) {
1196           Instruction &SkipI = *It;
1197           if (CIs.count(&SkipI)) {
1198             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1199                               << " due to " << I << "\n");
1200             ++It;
1201             break;
1202           }
1203         }
1204 
1205         // Store mergable regions found.
1206         if (MergableCIs.size() > 1) {
1207           MergableCIsVector.push_back(MergableCIs);
1208           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1209                             << " parallel regions in block " << BB->getName()
1210                             << " of function " << BB->getParent()->getName()
1211                             << "\n";);
1212         }
1213 
1214         MergableCIs.clear();
1215       }
1216 
1217       if (!MergableCIsVector.empty()) {
1218         Changed = true;
1219 
1220         for (auto &MergableCIs : MergableCIsVector)
1221           Merge(MergableCIs, BB);
1222         MergableCIsVector.clear();
1223       }
1224     }
1225 
1226     if (Changed) {
1227       /// Re-collect uses for fork calls, emitted barrier calls, and
1228       /// any emitted master/end_master calls.
1229       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1230       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1231       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1232       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1233     }
1234 
1235     return Changed;
1236   }
1237 
1238   /// Try to delete parallel regions if possible.
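  /// A region is deletable if its outlined callback only reads memory and is
  /// guaranteed to return, i.e., the fork call has no observable effect.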
1239   bool deleteParallelRegions() {
1240     const unsigned CallbackCalleeOperand = 2;
1241 
1242     OMPInformationCache::RuntimeFunctionInfo &RFI =
1243         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1244 
1245     if (!RFI.Declaration)
1246       return false;
1247 
1248     bool Changed = false;
1249     auto DeleteCallCB = [&](Use &U, Function &) {
1250       CallInst *CI = getCallIfRegularCall(U);
1251       if (!CI)
1252         return false;
1253       auto *Fn = dyn_cast<Function>(
1254           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1255       if (!Fn)
1256         return false;
1257       if (!Fn->onlyReadsMemory())
1258         return false;
1259       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1260         return false;
1261 
1262       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1263                         << CI->getCaller()->getName() << "\n");
1264 
1265       auto Remark = [&](OptimizationRemark OR) {
1266         return OR << "Removing parallel region with no side-effects.";
1267       };
1268       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1269 
1270       CGUpdater.removeCallSite(*CI);
1271       CI->eraseFromParent();
1272       Changed = true;
1273       ++NumOpenMPParallelRegionsDeleted;
1274       return true;
1275     };
1276 
1277     RFI.foreachUse(SCC, DeleteCallCB);
1278 
1279     return Changed;
1280   }
1281 
1282   /// Try to eliminate runtime calls by reusing existing ones.
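  /// A call is deduplicable if its result is assumed not to change within the
  /// surrounding function (see DeduplicableRuntimeCallIDs below).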
1283   bool deduplicateRuntimeCalls() {
1284     bool Changed = false;
1285 
1286     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1287         OMPRTL_omp_get_num_threads,
1288         OMPRTL_omp_in_parallel,
1289         OMPRTL_omp_get_cancellation,
1290         OMPRTL_omp_get_thread_limit,
1291         OMPRTL_omp_get_supported_active_levels,
1292         OMPRTL_omp_get_level,
1293         OMPRTL_omp_get_ancestor_thread_num,
1294         OMPRTL_omp_get_team_size,
1295         OMPRTL_omp_get_active_level,
1296         OMPRTL_omp_in_final,
1297         OMPRTL_omp_get_proc_bind,
1298         OMPRTL_omp_get_num_places,
1299         OMPRTL_omp_get_num_procs,
1300         OMPRTL_omp_get_place_num,
1301         OMPRTL_omp_get_partition_num_places,
1302         OMPRTL_omp_get_partition_place_nums};
1303 
1304     // Global-tid is handled separately.
1305     SmallSetVector<Value *, 16> GTIdArgs;
1306     collectGlobalThreadIdArguments(GTIdArgs);
1307     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1308                       << " global thread ID arguments\n");
1309 
1310     for (Function *F : SCC) {
1311       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1312         Changed |= deduplicateRuntimeCalls(
1313             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1314 
1315       // __kmpc_global_thread_num is special as we can replace it with an
1316       // argument in enough cases to make it worth trying.
1317       Value *GTIdArg = nullptr;
1318       for (Argument &Arg : F->args())
1319         if (GTIdArgs.count(&Arg)) {
1320           GTIdArg = &Arg;
1321           break;
1322         }
1323       Changed |= deduplicateRuntimeCalls(
1324           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1325     }
1326 
1327     return Changed;
1328   }
1329 
  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
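  ///
  /// Illustrative sketch of the result (argument lists abbreviated, names
  /// assumed, not taken from a real test):
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %struct.__tgt_async_info* %handle)
  ///   ... ; side-effect free code the transfer can overlap with
  ///   call void @__tgt_target_data_begin_mapper_wait(i64 %device_id, %struct.__tgt_async_info* %handle)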
1336   bool hideMemTransfersLatency() {
1337     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1338     bool Changed = false;
1339     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1340       auto *RTCall = getCallIfRegularCall(U, &RFI);
1341       if (!RTCall)
1342         return false;
1343 
1344       OffloadArray OffloadArrays[3];
1345       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1346         return false;
1347 
1348       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1349 
1350       // TODO: Check if can be moved upwards.
1351       bool WasSplit = false;
1352       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1353       if (WaitMovementPoint)
1354         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1355 
1356       Changed |= WasSplit;
1357       return WasSplit;
1358     };
1359     RFI.foreachUse(SCC, SplitMemTransfers);
1360 
1361     return Changed;
1362   }
1363 
1364   void analysisGlobalization() {
1365     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1366 
1367     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1368       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1369         auto Remark = [&](OptimizationRemarkMissed ORM) {
1370           return ORM
1371                  << "Found thread data sharing on the GPU. "
1372                  << "Expect degraded performance due to data globalization.";
1373         };
1374         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1375       }
1376 
1377       return false;
1378     };
1379 
1380     RFI.foreachUse(SCC, CheckGlobalization);
1381   }
1382 
1383   /// Maps the values stored in the offload arrays passed as arguments to
1384   /// \p RuntimeCall into the offload arrays in \p OAs.
1385   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1386                                 MutableArrayRef<OffloadArray> OAs) {
1387     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1388 
1389     // A runtime call that involves memory offloading looks something like:
1390     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1391     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1392     // ...)
1393     // So, the idea is to access the allocas that allocate space for these
1394     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1395     // Therefore:
1396     // i8** %offload_baseptrs.
1397     Value *BasePtrsArg =
1398         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1399     // i8** %offload_ptrs.
1400     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
    // i64* %offload_sizes.
1402     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1403 
1404     // Get values stored in **offload_baseptrs.
1405     auto *V = getUnderlyingObject(BasePtrsArg);
1406     if (!isa<AllocaInst>(V))
1407       return false;
1408     auto *BasePtrsArray = cast<AllocaInst>(V);
1409     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1410       return false;
1411 
    // Get values stored in **offload_ptrs.
1413     V = getUnderlyingObject(PtrsArg);
1414     if (!isa<AllocaInst>(V))
1415       return false;
1416     auto *PtrsArray = cast<AllocaInst>(V);
1417     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1418       return false;
1419 
1420     // Get values stored in **offload_sizes.
1421     V = getUnderlyingObject(SizesArg);
1422     // If it's a [constant] global array don't analyze it.
1423     if (isa<GlobalValue>(V))
1424       return isa<Constant>(V);
1425     if (!isa<AllocaInst>(V))
1426       return false;
1427 
1428     auto *SizesArray = cast<AllocaInst>(V);
1429     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1430       return false;
1431 
1432     return true;
1433   }
1434 
1435   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1436   /// For now this is a way to test that the function getValuesInOffloadArrays
1437   /// is working properly.
1438   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1439   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1440     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1441 
1442     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1443     std::string ValuesStr;
1444     raw_string_ostream Printer(ValuesStr);
1445     std::string Separator = " --- ";
1446 
1447     for (auto *BP : OAs[0].StoredValues) {
1448       BP->print(Printer);
1449       Printer << Separator;
1450     }
1451     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1452     ValuesStr.clear();
1453 
1454     for (auto *P : OAs[1].StoredValues) {
1455       P->print(Printer);
1456       Printer << Separator;
1457     }
1458     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1459     ValuesStr.clear();
1460 
1461     for (auto *S : OAs[2].StoredValues) {
1462       S->print(Printer);
1463       Printer << Separator;
1464     }
1465     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1466   }
1467 
  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
1470   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1471     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1472     //  Make it traverse the CFG.
1473 
1474     Instruction *CurrentI = &RuntimeCall;
1475     bool IsWorthIt = false;
1476     while ((CurrentI = CurrentI->getNextNode())) {
1477 
1478       // TODO: Once we detect the regions to be offloaded we should use the
1479       //  alias analysis manager to check if CurrentI may modify one of
1480       //  the offloaded regions.
1481       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1482         if (IsWorthIt)
1483           return CurrentI;
1484 
1485         return nullptr;
1486       }
1487 
      // FIXME: For now, moving the wait over anything without side effects
      //  is considered worth it.
1490       IsWorthIt = true;
1491     }
1492 
1493     // Return end of BasicBlock.
1494     return RuntimeCall.getParent()->getTerminator();
1495   }
1496 
1497   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1498   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1499                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It is used to store information about the async transfer,
    // allowing us to wait on it later.
1503     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1504     auto *F = RuntimeCall.getCaller();
1505     Instruction *FirstInst = &(F->getEntryBlock().front());
1506     AllocaInst *Handle = new AllocaInst(
1507         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1508 
1509     // Add "issue" runtime call declaration:
1510     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1511     //   i8**, i8**, i64*, i64*)
1512     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1513         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1514 
    // Replace the RuntimeCall call site with its asynchronous version.
1516     SmallVector<Value *, 16> Args;
1517     for (auto &Arg : RuntimeCall.args())
1518       Args.push_back(Arg.get());
1519     Args.push_back(Handle);
1520 
1521     CallInst *IssueCallsite =
1522         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1523     RuntimeCall.eraseFromParent();
1524 
1525     // Add "wait" runtime call declaration:
1526     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1527     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1528         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1529 
1530     Value *WaitParams[2] = {
1531         IssueCallsite->getArgOperand(
1532             OffloadArray::DeviceIDArgNum), // device_id.
1533         Handle                             // handle to wait on.
1534     };
1535     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1536 
1537     return true;
1538   }
1539 
1540   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1541                                     bool GlobalOnly, bool &SingleChoice) {
1542     if (CurrentIdent == NextIdent)
1543       return CurrentIdent;
1544 
1545     // TODO: Figure out how to actually combine multiple debug locations. For
1546     //       now we just keep an existing one if there is a single choice.
1547     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1548       SingleChoice = !CurrentIdent;
1549       return NextIdent;
1550     }
1551     return nullptr;
1552   }
1553 
  /// Return a `struct ident_t*` value that represents the ones used in the
  /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
  /// return a local `struct ident_t*`. For now, if we cannot find a suitable
  /// return value we create one from scratch. We also do not yet combine
  /// information, e.g., the source locations, see combinedIdentStruct.
1559   Value *
1560   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1561                                  Function &F, bool GlobalOnly) {
1562     bool SingleChoice = true;
1563     Value *Ident = nullptr;
1564     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1565       CallInst *CI = getCallIfRegularCall(U, &RFI);
1566       if (!CI || &F != &Caller)
1567         return false;
1568       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1569                                   /* GlobalOnly */ true, SingleChoice);
1570       return false;
1571     };
1572     RFI.foreachUse(SCC, CombineIdentStruct);
1573 
1574     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate, but we work around it for now.
1577       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1578         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1579             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1581       // TODO: Use the debug locations of the calls instead.
1582       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1583       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1584     }
1585     return Ident;
1586   }
1587 
1588   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1589   /// \p ReplVal if given.
1590   bool deduplicateRuntimeCalls(Function &F,
1591                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1592                                Value *ReplVal = nullptr) {
1593     auto *UV = RFI.getUseVector(F);
1594     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1595       return false;
1596 
    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value" : "") << "\n");
1600 
1601     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1602                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1603            "Unexpected replacement value!");
1604 
1605     // TODO: Use dominance to find a good position instead.
1606     auto CanBeMoved = [this](CallBase &CB) {
1607       unsigned NumArgs = CB.getNumArgOperands();
1608       if (NumArgs == 0)
1609         return true;
1610       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1611         return false;
1612       for (unsigned u = 1; u < NumArgs; ++u)
1613         if (isa<Instruction>(CB.getArgOperand(u)))
1614           return false;
1615       return true;
1616     };
1617 
1618     if (!ReplVal) {
1619       for (Use *U : *UV)
1620         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1621           if (!CanBeMoved(*CI))
1622             continue;
1623 
1624           // If the function is a kernel, dedup will move
1625           // the runtime call right after the kernel init callsite. Otherwise,
1626           // it will move it to the beginning of the caller function.
1627           if (isKernel(F)) {
1628             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1629             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1630 
1631             if (KernelInitUV->empty())
1632               continue;
1633 
1634             assert(KernelInitUV->size() == 1 &&
1635                    "Expected a single __kmpc_target_init in kernel\n");
1636 
1637             CallInst *KernelInitCI =
1638                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1639             assert(KernelInitCI &&
1640                    "Expected a call to __kmpc_target_init in kernel\n");
1641 
1642             CI->moveAfter(KernelInitCI);
1643           } else
1644             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1645           ReplVal = CI;
1646           break;
1647         }
1648       if (!ReplVal)
1649         return false;
1650     }
1651 
1652     // If we use a call as a replacement value we need to make sure the ident is
1653     // valid at the new location. For now we just pick a global one, either
1654     // existing and used by one of the calls, or created from scratch.
1655     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1656       if (!CI->arg_empty() &&
1657           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1658         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1659                                                       /* GlobalOnly */ true);
1660         CI->setArgOperand(0, Ident);
1661       }
1662     }
1663 
1664     bool Changed = false;
1665     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1666       CallInst *CI = getCallIfRegularCall(U, &RFI);
1667       if (!CI || CI == ReplVal || &F != &Caller)
1668         return false;
1669       assert(CI->getCaller() == &F && "Unexpected call!");
1670 
1671       auto Remark = [&](OptimizationRemark OR) {
1672         return OR << "OpenMP runtime call "
1673                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1674       };
1675       if (CI->getDebugLoc())
1676         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1677       else
1678         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1679 
1680       CGUpdater.removeCallSite(*CI);
1681       CI->replaceAllUsesWith(ReplVal);
1682       CI->eraseFromParent();
1683       ++NumOpenMPRuntimeCallsDeduplicated;
1684       Changed = true;
1685       return true;
1686     };
1687     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1688 
1689     return Changed;
1690   }
1691 
1692   /// Collect arguments that represent the global thread id in \p GTIdArgs.
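  ///
  /// For example (illustrative), given an internal function
  ///   static void body(int gtid);
  /// that is only ever called as body(__kmpc_global_thread_num(&loc)), the
  /// `gtid` argument is collected as a global thread ID argument.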
1693   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1694     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1695     //       initialization. We could define an AbstractAttribute instead and
1696     //       run the Attributor here once it can be run as an SCC pass.
1697 
1698     // Helper to check the argument \p ArgNo at all call sites of \p F for
1699     // a GTId.
1700     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1701       if (!F.hasLocalLinkage())
1702         return false;
1703       for (Use &U : F.uses()) {
1704         if (CallInst *CI = getCallIfRegularCall(U)) {
1705           Value *ArgOp = CI->getArgOperand(ArgNo);
1706           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1707               getCallIfRegularCall(
1708                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1709             continue;
1710         }
1711         return false;
1712       }
1713       return true;
1714     };
1715 
1716     // Helper to identify uses of a GTId as GTId arguments.
1717     auto AddUserArgs = [&](Value &GTId) {
1718       for (Use &U : GTId.uses())
1719         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1720           if (CI->isArgOperand(&U))
1721             if (Function *Callee = CI->getCalledFunction())
1722               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1723                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1724     };
1725 
1726     // The argument users of __kmpc_global_thread_num calls are GTIds.
1727     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1728         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1729 
1730     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1731       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1732         AddUserArgs(*CI);
1733       return false;
1734     });
1735 
1736     // Transitively search for more arguments by looking at the users of the
1737     // ones we know already. During the search the GTIdArgs vector is extended
1738     // so we cannot cache the size nor can we use a range based for.
1739     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1740       AddUserArgs(*GTIdArgs[u]);
1741   }
1742 
1743   /// Kernel (=GPU) optimizations and utility functions
1744   ///
1745   ///{{
1746 
  /// Check if \p F is a kernel, hence an entry point for target offloading.
1748   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1749 
1750   /// Cache to remember the unique kernel for a function.
1751   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1752 
1753   /// Find the unique kernel that will execute \p F, if any.
1754   Kernel getUniqueKernelFor(Function &F);
1755 
1756   /// Find the unique kernel that will execute \p I, if any.
1757   Kernel getUniqueKernelFor(Instruction &I) {
1758     return getUniqueKernelFor(*I.getFunction());
1759   }
1760 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
1763   bool rewriteDeviceCodeStateMachine();
1764 
1765   ///
1766   ///}}
1767 
1768   /// Emit a remark generically
1769   ///
1770   /// This template function can be used to generically emit a remark. The
1771   /// RemarkKind should be one of the following:
1772   ///   - OptimizationRemark to indicate a successful optimization attempt
1773   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1774   ///   - OptimizationRemarkAnalysis to provide additional information about an
1775   ///     optimization attempt
1776   ///
1777   /// The remark is built using a callback function provided by the caller that
1778   /// takes a RemarkKind as input and returns a RemarkKind.
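  ///
  /// Typical use, sketched (remark name and text are placeholders):
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "Something was optimized.";
  ///   };
  ///   emitRemark<OptimizationRemark>(CI, "OMP000", Remark);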
1779   template <typename RemarkKind, typename RemarkCallBack>
1780   void emitRemark(Instruction *I, StringRef RemarkName,
1781                   RemarkCallBack &&RemarkCB) const {
1782     Function *F = I->getParent()->getParent();
1783     auto &ORE = OREGetter(F);
1784 
1785     if (RemarkName.startswith("OMP"))
1786       ORE.emit([&]() {
1787         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1788                << " [" << RemarkName << "]";
1789       });
1790     else
1791       ORE.emit(
1792           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1793   }
1794 
1795   /// Emit a remark on a function.
1796   template <typename RemarkKind, typename RemarkCallBack>
1797   void emitRemark(Function *F, StringRef RemarkName,
1798                   RemarkCallBack &&RemarkCB) const {
1799     auto &ORE = OREGetter(F);
1800 
1801     if (RemarkName.startswith("OMP"))
1802       ORE.emit([&]() {
1803         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1804                << " [" << RemarkName << "]";
1805       });
1806     else
1807       ORE.emit(
1808           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1809   }
1810 
1811   /// RAII struct to temporarily change an RTL function's linkage to external.
1812   /// This prevents it from being mistakenly removed by other optimizations.
1813   struct ExternalizationRAII {
1814     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1815                         RuntimeFunction RFKind)
1816         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1817       if (!Declaration)
1818         return;
1819 
1820       LinkageType = Declaration->getLinkage();
1821       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1822     }
1823 
1824     ~ExternalizationRAII() {
1825       if (!Declaration)
1826         return;
1827 
1828       Declaration->setLinkage(LinkageType);
1829     }
1830 
1831     Function *Declaration;
1832     GlobalValue::LinkageTypes LinkageType;
1833   };
1834 
1835   /// The underlying module.
1836   Module &M;
1837 
1838   /// The SCC we are operating on.
1839   SmallVectorImpl<Function *> &SCC;
1840 
  /// Callback to update the call graph; the first argument is a removed call,
  /// the second an optional replacement call.
1843   CallGraphUpdater &CGUpdater;
1844 
1845   /// Callback to get an OptimizationRemarkEmitter from a Function *
1846   OptimizationRemarkGetter OREGetter;
1847 
  /// OpenMP-specific information cache. Also used for Attributor runs.
1849   OMPInformationCache &OMPInfoCache;
1850 
1851   /// Attributor instance.
1852   Attributor &A;
1853 
1854   /// Helper function to run Attributor on SCC.
1855   bool runAttributor(bool IsModulePass) {
1856     if (SCC.empty())
1857       return false;
1858 
    // Temporarily give these functions external linkage so the Attributor
    // doesn't remove them when we try to look them up later.
1861     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1862     ExternalizationRAII EndParallel(OMPInfoCache,
1863                                     OMPRTL___kmpc_kernel_end_parallel);
1864     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1865                                     OMPRTL___kmpc_barrier_simple_spmd);
1866 
1867     registerAAs(IsModulePass);
1868 
1869     ChangeStatus Changed = A.run();
1870 
1871     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1872                       << " functions, result: " << Changed << ".\n");
1873 
1874     return Changed == ChangeStatus::CHANGED;
1875   }
1876 
1877   void registerFoldRuntimeCall(RuntimeFunction RF);
1878 
1879   /// Populate the Attributor with abstract attribute opportunities in the
1880   /// function.
1881   void registerAAs(bool IsModulePass);
1882 };
1883 
1884 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1885   if (!OMPInfoCache.ModuleSlice.count(&F))
1886     return nullptr;
1887 
1888   // Use a scope to keep the lifetime of the CachedKernel short.
1889   {
1890     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1891     if (CachedKernel)
1892       return *CachedKernel;
1893 
1894     // TODO: We should use an AA to create an (optimistic and callback
1895     //       call-aware) call graph. For now we stick to simple patterns that
1896     //       are less powerful, basically the worst fixpoint.
1897     if (isKernel(F)) {
1898       CachedKernel = Kernel(&F);
1899       return *CachedKernel;
1900     }
1901 
1902     CachedKernel = nullptr;
1903     if (!F.hasLocalLinkage()) {
1904 
1905       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1906       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1907         return ORA << "Potentially unknown OpenMP target region caller.";
1908       };
1909       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1910 
1911       return nullptr;
1912     }
1913   }
1914 
1915   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1916     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1917       // Allow use in equality comparisons.
1918       if (Cmp->isEquality())
1919         return getUniqueKernelFor(*Cmp);
1920       return nullptr;
1921     }
1922     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1923       // Allow direct calls.
1924       if (CB->isCallee(&U))
1925         return getUniqueKernelFor(*CB);
1926 
1927       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1928           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1929       // Allow the use in __kmpc_parallel_51 calls.
1930       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1931         return getUniqueKernelFor(*CB);
1932       return nullptr;
1933     }
1934     // Disallow every other use.
1935     return nullptr;
1936   };
1937 
1938   // TODO: In the future we want to track more than just a unique kernel.
1939   SmallPtrSet<Kernel, 2> PotentialKernels;
1940   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1941     PotentialKernels.insert(GetUniqueKernelForUse(U));
1942   });
1943 
1944   Kernel K = nullptr;
1945   if (PotentialKernels.size() == 1)
1946     K = *PotentialKernels.begin();
1947 
1948   // Cache the result.
1949   UniqueKernelMap[&F] = K;
1950 
1951   return K;
1952 }
1953 
1954 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1955   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1956       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1957 
1958   bool Changed = false;
1959   if (!KernelParallelRFI)
1960     return Changed;
1961 
1962   // If we have disabled state machine changes, exit
1963   if (DisableOpenMPOptStateMachineRewrite)
1964     return Changed;
1965 
1966   for (Function *F : SCC) {
1967 
1968     // Check if the function is a use in a __kmpc_parallel_51 call at
1969     // all.
1970     bool UnknownUse = false;
1971     bool KernelParallelUse = false;
1972     unsigned NumDirectCalls = 0;
1973 
1974     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1975     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1976       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1977         if (CB->isCallee(&U)) {
1978           ++NumDirectCalls;
1979           return;
1980         }
1981 
1982       if (isa<ICmpInst>(U.getUser())) {
1983         ToBeReplacedStateMachineUses.push_back(&U);
1984         return;
1985       }
1986 
1987       // Find wrapper functions that represent parallel kernels.
1988       CallInst *CI =
1989           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1990       const unsigned int WrapperFunctionArgNo = 6;
1991       if (!KernelParallelUse && CI &&
1992           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1993         KernelParallelUse = true;
1994         ToBeReplacedStateMachineUses.push_back(&U);
1995         return;
1996       }
1997       UnknownUse = true;
1998     });
1999 
2000     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2001     // use.
2002     if (!KernelParallelUse)
2003       continue;
2004 
2005     // If this ever hits, we should investigate.
2006     // TODO: Checking the number of uses is not a necessary restriction and
2007     // should be lifted.
2008     if (UnknownUse || NumDirectCalls != 1 ||
2009         ToBeReplacedStateMachineUses.size() > 2) {
2010       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2011         return ORA << "Parallel region is used in "
2012                    << (UnknownUse ? "unknown" : "unexpected")
2013                    << " ways. Will not attempt to rewrite the state machine.";
2014       };
2015       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2016       continue;
2017     }
2018 
2019     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2020     // up if the function is not called from a unique kernel.
2021     Kernel K = getUniqueKernelFor(*F);
2022     if (!K) {
2023       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2024         return ORA << "Parallel region is not called from a unique kernel. "
2025                       "Will not attempt to rewrite the state machine.";
2026       };
2027       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2028       continue;
2029     }
2030 
    // We now know F is a parallel body function called only from the kernel K.
    // We also identified the state machine uses in which we will replace the
    // function pointer with a new global symbol used purely for identification.
    // This ensures only direct calls to the function are left.
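    //
    // Sketch of the rewrite (names illustrative, not from a real test):
    //   before: call void @__kmpc_parallel_51(..., i8* bitcast (... @__omp_outlined__1_wrapper to i8*), ...)
    //   after:  call void @__kmpc_parallel_51(..., i8* @__omp_outlined__1_wrapper.ID, ...)
    // The kernel's state machine then compares against the ".ID" global
    // instead of the wrapper's real address.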
2035 
2036     Module &M = *F->getParent();
2037     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2038 
2039     auto *ID = new GlobalVariable(
2040         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2041         UndefValue::get(Int8Ty), F->getName() + ".ID");
2042 
2043     for (Use *U : ToBeReplacedStateMachineUses)
2044       U->set(ConstantExpr::getBitCast(ID, U->get()->getType()));
2045 
2046     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2047 
2048     Changed = true;
2049   }
2050 
2051   return Changed;
2052 }
2053 
2054 /// Abstract Attribute for tracking ICV values.
2055 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2056   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2057   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2058 
2059   void initialize(Attributor &A) override {
2060     Function *F = getAnchorScope();
2061     if (!F || !A.isFunctionIPOAmendable(*F))
2062       indicatePessimisticFixpoint();
2063   }
2064 
2065   /// Returns true if value is assumed to be tracked.
2066   bool isAssumedTracked() const { return getAssumed(); }
2067 
2068   /// Returns true if value is known to be tracked.
  bool isKnownTracked() const { return getKnown(); }
2070 
  /// Create an abstract attribute view for the position \p IRP.
2072   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2073 
2074   /// Return the value with which \p I can be replaced for specific \p ICV.
2075   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2076                                                 const Instruction *I,
2077                                                 Attributor &A) const {
2078     return None;
2079   }
2080 
  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return nullptr. If it is not clear yet, return None.
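  ///
  /// For example (illustrative), if the only setter reaching the queried
  /// position is `call void @omp_set_num_threads(i32 4)`, the unique value for
  /// the nthreads ICV is that `i32 4` argument.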
2084   virtual Optional<Value *>
2085   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2086 
  // Currently only nthreads is being tracked.
  // This array will only grow with time.
2089   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2090 
2091   /// See AbstractAttribute::getName()
2092   const std::string getName() const override { return "AAICVTracker"; }
2093 
2094   /// See AbstractAttribute::getIdAddr()
2095   const char *getIdAddr() const override { return &ID; }
2096 
2097   /// This function should return true if the type of the \p AA is AAICVTracker
2098   static bool classof(const AbstractAttribute *AA) {
2099     return (AA->getIdAddr() == &ID);
2100   }
2101 
2102   static const char ID;
2103 };
2104 
2105 struct AAICVTrackerFunction : public AAICVTracker {
2106   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2107       : AAICVTracker(IRP, A) {}
2108 
2109   // FIXME: come up with better string.
2110   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2111 
2112   // FIXME: come up with some stats.
2113   void trackStatistics() const override {}
2114 
2115   /// We don't manifest anything for this AA.
2116   ChangeStatus manifest(Attributor &A) override {
2117     return ChangeStatus::UNCHANGED;
2118   }
2119 
2120   // Map of ICV to their values at specific program point.
2121   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2122                   InternalControlVar::ICV___last>
2123       ICVReplacementValuesMap;
2124 
2125   ChangeStatus updateImpl(Attributor &A) override {
2126     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2127 
2128     Function *F = getAnchorScope();
2129 
2130     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2131 
2132     for (InternalControlVar ICV : TrackableICVs) {
2133       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2134 
2135       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2136       auto TrackValues = [&](Use &U, Function &) {
2137         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2138         if (!CI)
2139           return false;
2140 
        // FIXME: Handle setters with more than one argument.
        // Track the new value.
2143         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2144           HasChanged = ChangeStatus::CHANGED;
2145 
2146         return false;
2147       };
2148 
2149       auto CallCheck = [&](Instruction &I) {
2150         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2151         if (ReplVal.hasValue() &&
2152             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2153           HasChanged = ChangeStatus::CHANGED;
2154 
2155         return true;
2156       };
2157 
2158       // Track all changes of an ICV.
2159       SetterRFI.foreachUse(TrackValues, F);
2160 
2161       bool UsedAssumedInformation = false;
2162       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2163                                 UsedAssumedInformation,
2164                                 /* CheckBBLivenessOnly */ true);
2165 
      // TODO: Figure out a way to avoid adding an entry to
      //       ICVReplacementValuesMap.
2168       Instruction *Entry = &F->getEntryBlock().front();
2169       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2170         ValuesMap.insert(std::make_pair(Entry, nullptr));
2171     }
2172 
2173     return HasChanged;
2174   }
2175 
  /// Helper to check if \p I is a call and get the value for it if it is
  /// unique.
2178   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2179                                     InternalControlVar &ICV) const {
2180 
2181     const auto *CB = dyn_cast<CallBase>(I);
2182     if (!CB || CB->hasFnAttr("no_openmp") ||
2183         CB->hasFnAttr("no_openmp_routines"))
2184       return None;
2185 
2186     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2187     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2188     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2189     Function *CalledFunction = CB->getCalledFunction();
2190 
2191     // Indirect call, assume ICV changes.
2192     if (CalledFunction == nullptr)
2193       return nullptr;
2194     if (CalledFunction == GetterRFI.Declaration)
2195       return None;
2196     if (CalledFunction == SetterRFI.Declaration) {
2197       if (ICVReplacementValuesMap[ICV].count(I))
2198         return ICVReplacementValuesMap[ICV].lookup(I);
2199 
2200       return nullptr;
2201     }
2202 
2203     // Since we don't know, assume it changes the ICV.
2204     if (CalledFunction->isDeclaration())
2205       return nullptr;
2206 
2207     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2208         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2209 
2210     if (ICVTrackingAA.isAssumedTracked())
2211       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2212 
2213     // If we don't know, assume it changes.
2214     return nullptr;
2215   }
2216 
2217   // We don't check unique value for a function, so return None.
2218   Optional<Value *>
2219   getUniqueReplacementValue(InternalControlVar ICV) const override {
2220     return None;
2221   }
2222 
2223   /// Return the value with which \p I can be replaced for specific \p ICV.
2224   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2225                                         const Instruction *I,
2226                                         Attributor &A) const override {
2227     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2228     if (ValuesMap.count(I))
2229       return ValuesMap.lookup(I);
2230 
2231     SmallVector<const Instruction *, 16> Worklist;
2232     SmallPtrSet<const Instruction *, 16> Visited;
2233     Worklist.push_back(I);
2234 
2235     Optional<Value *> ReplVal;
2236 
2237     while (!Worklist.empty()) {
2238       const Instruction *CurrInst = Worklist.pop_back_val();
2239       if (!Visited.insert(CurrInst).second)
2240         continue;
2241 
2242       const BasicBlock *CurrBB = CurrInst->getParent();
2243 
2244       // Go up and look for all potential setters/calls that might change the
2245       // ICV.
2246       while ((CurrInst = CurrInst->getPrevNode())) {
2247         if (ValuesMap.count(CurrInst)) {
2248           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2249           // Unknown value, track new.
2250           if (!ReplVal.hasValue()) {
2251             ReplVal = NewReplVal;
2252             break;
2253           }
2254 
2255           // If we found a new value, we can't know the icv value anymore.
2256           if (NewReplVal.hasValue())
2257             if (ReplVal != NewReplVal)
2258               return nullptr;
2259 
2260           break;
2261         }
2262 
2263         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2264         if (!NewReplVal.hasValue())
2265           continue;
2266 
2267         // Unknown value, track new.
2268         if (!ReplVal.hasValue()) {
2269           ReplVal = NewReplVal;
2270           break;
2271         }
2272 
        // We found a new value, so we can't know the ICV value anymore.
        if (ReplVal != NewReplVal)
          return nullptr;
2277       }
2278 
2279       // If we are in the same BB and we have a value, we are done.
2280       if (CurrBB == I->getParent() && ReplVal.hasValue())
2281         return ReplVal;
2282 
2283       // Go through all predecessors and add terminators for analysis.
2284       for (const BasicBlock *Pred : predecessors(CurrBB))
2285         if (const Instruction *Terminator = Pred->getTerminator())
2286           Worklist.push_back(Terminator);
2287     }
2288 
2289     return ReplVal;
2290   }
2291 };
2292 
2293 struct AAICVTrackerFunctionReturned : AAICVTracker {
2294   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2295       : AAICVTracker(IRP, A) {}
2296 
2297   // FIXME: come up with better string.
2298   const std::string getAsStr() const override {
2299     return "ICVTrackerFunctionReturned";
2300   }
2301 
2302   // FIXME: come up with some stats.
2303   void trackStatistics() const override {}
2304 
2305   /// We don't manifest anything for this AA.
2306   ChangeStatus manifest(Attributor &A) override {
2307     return ChangeStatus::UNCHANGED;
2308   }
2309 
2310   // Map of ICV to their values at specific program point.
2311   EnumeratedArray<Optional<Value *>, InternalControlVar,
2312                   InternalControlVar::ICV___last>
2313       ICVReplacementValuesMap;
2314 
  /// Return the unique ICV value assumed at the function's returns for the
  /// specific \p ICV.
2316   Optional<Value *>
2317   getUniqueReplacementValue(InternalControlVar ICV) const override {
2318     return ICVReplacementValuesMap[ICV];
2319   }
2320 
2321   ChangeStatus updateImpl(Attributor &A) override {
2322     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2323     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2324         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2325 
2326     if (!ICVTrackingAA.isAssumedTracked())
2327       return indicatePessimisticFixpoint();
2328 
2329     for (InternalControlVar ICV : TrackableICVs) {
2330       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2331       Optional<Value *> UniqueICVValue;
2332 
2333       auto CheckReturnInst = [&](Instruction &I) {
2334         Optional<Value *> NewReplVal =
2335             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2336 
2337         // If we found a second ICV value there is no unique returned value.
2338         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2339           return false;
2340 
2341         UniqueICVValue = NewReplVal;
2342 
2343         return true;
2344       };
2345 
2346       bool UsedAssumedInformation = false;
2347       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2348                                      UsedAssumedInformation,
2349                                      /* CheckBBLivenessOnly */ true))
2350         UniqueICVValue = nullptr;
2351 
2352       if (UniqueICVValue == ReplVal)
2353         continue;
2354 
2355       ReplVal = UniqueICVValue;
2356       Changed = ChangeStatus::CHANGED;
2357     }
2358 
2359     return Changed;
2360   }
2361 };
2362 
2363 struct AAICVTrackerCallSite : AAICVTracker {
2364   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2365       : AAICVTracker(IRP, A) {}
2366 
2367   void initialize(Attributor &A) override {
2368     Function *F = getAnchorScope();
2369     if (!F || !A.isFunctionIPOAmendable(*F))
2370       indicatePessimisticFixpoint();
2371 
2372     // We only initialize this AA for getters, so we need to know which ICV it
2373     // gets.
2374     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2375     for (InternalControlVar ICV : TrackableICVs) {
2376       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2377       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2378       if (Getter.Declaration == getAssociatedFunction()) {
2379         AssociatedICV = ICVInfo.Kind;
2380         return;
2381       }
2382     }
2383 
2384     /// Unknown ICV.
2385     indicatePessimisticFixpoint();
2386   }
2387 
2388   ChangeStatus manifest(Attributor &A) override {
2389     if (!ReplVal.hasValue() || !ReplVal.getValue())
2390       return ChangeStatus::UNCHANGED;
2391 
2392     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2393     A.deleteAfterManifest(*getCtxI());
2394 
2395     return ChangeStatus::CHANGED;
2396   }
2397 
2398   // FIXME: come up with better string.
2399   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2400 
2401   // FIXME: come up with some stats.
2402   void trackStatistics() const override {}
2403 
2404   InternalControlVar AssociatedICV;
2405   Optional<Value *> ReplVal;
2406 
2407   ChangeStatus updateImpl(Attributor &A) override {
2408     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2409         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2410 
2411     // We don't have any information, so we assume it changes the ICV.
2412     if (!ICVTrackingAA.isAssumedTracked())
2413       return indicatePessimisticFixpoint();
2414 
2415     Optional<Value *> NewReplVal =
2416         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2417 
2418     if (ReplVal == NewReplVal)
2419       return ChangeStatus::UNCHANGED;
2420 
2421     ReplVal = NewReplVal;
2422     return ChangeStatus::CHANGED;
2423   }
2424 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2427   Optional<Value *>
2428   getUniqueReplacementValue(InternalControlVar ICV) const override {
2429     return ReplVal;
2430   }
2431 };
2432 
2433 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2434   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2435       : AAICVTracker(IRP, A) {}
2436 
2437   // FIXME: come up with better string.
2438   const std::string getAsStr() const override {
2439     return "ICVTrackerCallSiteReturned";
2440   }
2441 
2442   // FIXME: come up with some stats.
2443   void trackStatistics() const override {}
2444 
2445   /// We don't manifest anything for this AA.
2446   ChangeStatus manifest(Attributor &A) override {
2447     return ChangeStatus::UNCHANGED;
2448   }
2449 
2450   // Map of ICV to their values at specific program point.
2451   EnumeratedArray<Optional<Value *>, InternalControlVar,
2452                   InternalControlVar::ICV___last>
2453       ICVReplacementValuesMap;
2454 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2457   Optional<Value *>
2458   getUniqueReplacementValue(InternalControlVar ICV) const override {
2459     return ICVReplacementValuesMap[ICV];
2460   }
2461 
2462   ChangeStatus updateImpl(Attributor &A) override {
2463     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2464     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2465         *this, IRPosition::returned(*getAssociatedFunction()),
2466         DepClassTy::REQUIRED);
2467 
2468     // We don't have any information, so we assume it changes the ICV.
2469     if (!ICVTrackingAA.isAssumedTracked())
2470       return indicatePessimisticFixpoint();
2471 
2472     for (InternalControlVar ICV : TrackableICVs) {
2473       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2474       Optional<Value *> NewReplVal =
2475           ICVTrackingAA.getUniqueReplacementValue(ICV);
2476 
2477       if (ReplVal == NewReplVal)
2478         continue;
2479 
2480       ReplVal = NewReplVal;
2481       Changed = ChangeStatus::CHANGED;
2482     }
2483     return Changed;
2484   }
2485 };
2486 
2487 struct AAExecutionDomainFunction : public AAExecutionDomain {
2488   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2489       : AAExecutionDomain(IRP, A) {}
2490 
2491   const std::string getAsStr() const override {
2492     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2493            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2494   }
2495 
2496   /// See AbstractAttribute::trackStatistics().
2497   void trackStatistics() const override {}
2498 
2499   void initialize(Attributor &A) override {
2500     Function *F = getAnchorScope();
2501     for (const auto &BB : *F)
2502       SingleThreadedBBs.insert(&BB);
2503     NumBBs = SingleThreadedBBs.size();
2504   }
2505 
2506   ChangeStatus manifest(Attributor &A) override {
2507     LLVM_DEBUG({
2508       for (const BasicBlock *BB : SingleThreadedBBs)
2509         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2510                << BB->getName() << " is executed by a single thread.\n";
2511     });
2512     return ChangeStatus::UNCHANGED;
2513   }
2514 
2515   ChangeStatus updateImpl(Attributor &A) override;
2516 
2517   /// Check if an instruction is executed by a single thread.
2518   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2519     return isExecutedByInitialThreadOnly(*I.getParent());
2520   }
2521 
2522   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2523     return isValidState() && SingleThreadedBBs.contains(&BB);
2524   }
2525 
2526   /// Set of basic blocks that are executed by a single thread.
2527   DenseSet<const BasicBlock *> SingleThreadedBBs;
2528 
2529   /// Total number of basic blocks in this function.
2530   long unsigned NumBBs;
2531 };
2532 
2533 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2534   Function *F = getAnchorScope();
2535   ReversePostOrderTraversal<Function *> RPOT(F);
2536   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2537 
2538   bool AllCallSitesKnown;
2539   auto PredForCallSite = [&](AbstractCallSite ACS) {
2540     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2541         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2542         DepClassTy::REQUIRED);
2543     return ACS.isDirectCall() &&
2544            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2545                *ACS.getInstruction());
2546   };
2547 
2548   if (!A.checkForAllCallSites(PredForCallSite, *this,
2549                               /* RequiresAllCallSites */ true,
2550                               AllCallSitesKnown))
2551     SingleThreadedBBs.erase(&F->getEntryBlock());
2552 
2553   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2554   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2555 
  // Check if the edge into the successor block compares the __kmpc_target_init
  // result with -1. If we are in non-SPMD mode, that signals that only the
  // main thread will execute the edge.
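  //
  // Illustrative pattern being matched (names assumed):
  //   %tid = call i32 @__kmpc_target_init(%struct.ident_t* @ident, i1 false, ...)
  //   %is.main = icmp eq i32 %tid, -1
  //   br i1 %is.main, label %main.thread.only, label %other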
2559   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2560     if (!Edge || !Edge->isConditional())
2561       return false;
2562     if (Edge->getSuccessor(0) != SuccessorBB)
2563       return false;
2564 
2565     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2566     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2567       return false;
2568 
2569     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2570     if (!C)
2571       return false;
2572 
2573     // Match:  -1 == __kmpc_target_init (for non-SPMD kernels only!)
2574     if (C->isAllOnesValue()) {
2575       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2576       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2577       if (!CB)
2578         return false;
2579       const int InitIsSPMDArgNo = 1;
2580       auto *IsSPMDModeCI =
2581           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2582       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2583     }
2584 
2585     return false;
2586   };
2587 
2588   // Merge all the predecessor states into the current basic block. A basic
2589   // block is executed by a single thread if all of its predecessors are.
2590   auto MergePredecessorStates = [&](BasicBlock *BB) {
2591     if (pred_begin(BB) == pred_end(BB))
2592       return SingleThreadedBBs.contains(BB);
2593 
2594     bool IsInitialThread = true;
2595     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2596          PredBB != PredEndBB; ++PredBB) {
2597       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2598                                BB))
2599         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2600     }
2601 
2602     return IsInitialThread;
2603   };
2604 
2605   for (auto *BB : RPOT) {
2606     if (!MergePredecessorStates(BB))
2607       SingleThreadedBBs.erase(BB);
2608   }
2609 
2610   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2611              ? ChangeStatus::UNCHANGED
2612              : ChangeStatus::CHANGED;
2613 }
2614 
2615 /// Try to replace memory allocation calls called by a single thread with a
2616 /// static buffer of shared memory.
2617 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2618   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2619   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2620 
2621   /// Create an abstract attribute view for the position \p IRP.
2622   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2623                                            Attributor &A);
2624 
2625   /// Returns true if HeapToShared conversion is assumed to be possible.
2626   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2627 
2628   /// Returns true if HeapToShared conversion is assumed and the CB is a
2629   /// callsite to a free operation to be removed.
2630   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2631 
2632   /// See AbstractAttribute::getName().
2633   const std::string getName() const override { return "AAHeapToShared"; }
2634 
2635   /// See AbstractAttribute::getIdAddr().
2636   const char *getIdAddr() const override { return &ID; }
2637 
2638   /// This function should return true if the type of the \p AA is
2639   /// AAHeapToShared.
2640   static bool classof(const AbstractAttribute *AA) {
2641     return (AA->getIdAddr() == &ID);
2642   }
2643 
2644   /// Unique ID (due to the unique address)
2645   static const char ID;
2646 };
2647 
2648 struct AAHeapToSharedFunction : public AAHeapToShared {
2649   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2650       : AAHeapToShared(IRP, A) {}
2651 
2652   const std::string getAsStr() const override {
2653     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2654            " malloc calls eligible.";
2655   }
2656 
2657   /// See AbstractAttribute::trackStatistics().
2658   void trackStatistics() const override {}
2659 
  /// This function finds free calls that will be removed by the
  /// HeapToShared transformation.
2662   void findPotentialRemovedFreeCalls(Attributor &A) {
2663     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2664     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2665 
2666     PotentialRemovedFreeCalls.clear();
2667     // Update free call users of found malloc calls.
2668     for (CallBase *CB : MallocCalls) {
2669       SmallVector<CallBase *, 4> FreeCalls;
2670       for (auto *U : CB->users()) {
2671         CallBase *C = dyn_cast<CallBase>(U);
2672         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2673           FreeCalls.push_back(C);
2674       }
2675 
2676       if (FreeCalls.size() != 1)
2677         continue;
2678 
2679       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2680     }
2681   }
2682 
2683   void initialize(Attributor &A) override {
2684     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2685     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2686 
2687     for (User *U : RFI.Declaration->users())
2688       if (CallBase *CB = dyn_cast<CallBase>(U))
2689         MallocCalls.insert(CB);
2690 
2691     findPotentialRemovedFreeCalls(A);
2692   }
2693 
2694   bool isAssumedHeapToShared(CallBase &CB) const override {
2695     return isValidState() && MallocCalls.count(&CB);
2696   }
2697 
2698   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2699     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2700   }
2701 
2702   ChangeStatus manifest(Attributor &A) override {
2703     if (MallocCalls.empty())
2704       return ChangeStatus::UNCHANGED;
2705 
2706     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2707     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2708 
2709     Function *F = getAnchorScope();
2710     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2711                                             DepClassTy::OPTIONAL);
2712 
2713     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2714     for (CallBase *CB : MallocCalls) {
2715       // Skip replacing this if HeapToStack has already claimed it.
2716       if (HS && HS->isAssumedHeapToStack(*CB))
2717         continue;
2718 
2719       // Find the unique free call to remove it.
2720       SmallVector<CallBase *, 4> FreeCalls;
2721       for (auto *U : CB->users()) {
2722         CallBase *C = dyn_cast<CallBase>(U);
2723         if (C && C->getCalledFunction() == FreeCall.Declaration)
2724           FreeCalls.push_back(C);
2725       }
2726       if (FreeCalls.size() != 1)
2727         continue;
2728 
2729       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2730 
2731       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2732                         << " with " << AllocSize->getZExtValue()
2733                         << " bytes of shared memory\n");
2734 
2735       // Create a new shared memory buffer of the same size as the allocation
2736       // and replace all the uses of the original allocation with it.
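      //
      // Illustrative result (names assumed):
      //   %x = call i8* @__kmpc_alloc_shared(i64 8)
      // becomes
      //   @x_shared = internal addrspace(3) global [8 x i8] undef, align 32
      // with uses of %x rewritten to an address-space cast of @x_shared and
      // the matching __kmpc_free_shared call removed.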
2737       Module *M = CB->getModule();
2738       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2739       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2740       auto *SharedMem = new GlobalVariable(
2741           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2742           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2743           GlobalValue::NotThreadLocal,
2744           static_cast<unsigned>(AddressSpace::Shared));
2745       auto *NewBuffer =
2746           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2747 
2748       auto Remark = [&](OptimizationRemark OR) {
2749         return OR << "Replaced globalized variable with "
2750                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2751                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2752                   << "of shared memory.";
2753       };
2754       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2755 
2756       SharedMem->setAlignment(MaybeAlign(32));
2757 
2758       A.changeValueAfterManifest(*CB, *NewBuffer);
2759       A.deleteAfterManifest(*CB);
2760       A.deleteAfterManifest(*FreeCalls.front());
2761 
2762       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2763       Changed = ChangeStatus::CHANGED;
2764     }
2765 
2766     return Changed;
2767   }
2768 
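  /// Prune the candidate set: keep only allocations with a constant size that
  /// are (assumed to be) executed by the initial thread only, as established
  /// via AAExecutionDomain.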
2769   ChangeStatus updateImpl(Attributor &A) override {
2770     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2771     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2772     Function *F = getAnchorScope();
2773 
2774     auto NumMallocCalls = MallocCalls.size();
2775 
2776     // Only consider malloc calls with a constant size executed by the initial thread.
2777     for (User *U : RFI.Declaration->users()) {
2778       const auto &ED = A.getAAFor<AAExecutionDomain>(
2779           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2780       if (CallBase *CB = dyn_cast<CallBase>(U))
2781         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2782             !ED.isExecutedByInitialThreadOnly(*CB))
2783           MallocCalls.erase(CB);
2784     }
2785 
2786     findPotentialRemovedFreeCalls(A);
2787 
2788     if (NumMallocCalls != MallocCalls.size())
2789       return ChangeStatus::CHANGED;
2790 
2791     return ChangeStatus::UNCHANGED;
2792   }
2793 
2794   /// Collection of all malloc calls in a function.
2795   SmallPtrSet<CallBase *, 4> MallocCalls;
2796   /// Collection of potentially removed free calls in a function.
2797   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2798 };
2799 
2800 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2801   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2802   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2803 
2804   /// Statistics are tracked as part of manifest for now.
2805   void trackStatistics() const override {}
2806 
2807   /// See AbstractAttribute::getAsStr()
2808   const std::string getAsStr() const override {
2809     if (!isValidState())
2810       return "<invalid>";
2811     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2812                                                             : "generic") +
2813            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2814                                                                : "") +
2815            std::string(" #PRs: ") +
2816            std::to_string(ReachedKnownParallelRegions.size()) +
2817            ", #Unknown PRs: " +
2818            std::to_string(ReachedUnknownParallelRegions.size());
2819   }
2820 
2821   /// Create an abstract attribute view for the position \p IRP.
2822   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2823 
2824   /// See AbstractAttribute::getName()
2825   const std::string getName() const override { return "AAKernelInfo"; }
2826 
2827   /// See AbstractAttribute::getIdAddr()
2828   const char *getIdAddr() const override { return &ID; }
2829 
2830   /// This function should return true if the type of the \p AA is AAKernelInfo
2831   static bool classof(const AbstractAttribute *AA) {
2832     return (AA->getIdAddr() == &ID);
2833   }
2834 
2835   static const char ID;
2836 };
2837 
2838 /// The function kernel info abstract attribute, basically, what can we say
2839 /// about a function with regards to the KernelInfoState.
2840 struct AAKernelInfoFunction : AAKernelInfo {
2841   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2842       : AAKernelInfo(IRP, A) {}
2843 
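  /// Instructions that have already been assigned to a guarded region, so we
  /// do not guard them twice.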
2844   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2845 
2846   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2847     return GuardedInstructions;
2848   }
2849 
2850   /// See AbstractAttribute::initialize(...).
2851   void initialize(Attributor &A) override {
2852     // This is a high-level transform that might change the constant arguments
2853     // of the init and deinit calls. We need to tell the Attributor about this
2854     // to avoid other parts using the current constant value for simplification.
2855     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2856 
2857     Function *Fn = getAnchorScope();
2858     if (!OMPInfoCache.Kernels.count(Fn))
2859       return;
2860 
2861     // Add itself to the reaching kernel and set IsKernelEntry.
2862     ReachingKernelEntries.insert(Fn);
2863     IsKernelEntry = true;
2864 
2865     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2866         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2867     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2868         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2869 
2870     // For kernels we perform more initialization work; first we find the init
2871     // and deinit calls.
2872     auto StoreCallBase = [](Use &U,
2873                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2874                             CallBase *&Storage) {
2875       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2876       assert(CB &&
2877              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2878       assert(!Storage &&
2879              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2880       Storage = CB;
2881       return false;
2882     };
2883     InitRFI.foreachUse(
2884         [&](Use &U, Function &) {
2885           StoreCallBase(U, InitRFI, KernelInitCB);
2886           return false;
2887         },
2888         Fn);
2889     DeinitRFI.foreachUse(
2890         [&](Use &U, Function &) {
2891           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2892           return false;
2893         },
2894         Fn);
2895 
2896     // Ignore kernels without initializers such as global constructors.
2897     if (!KernelInitCB || !KernelDeinitCB) {
2898       indicateOptimisticFixpoint();
2899       return;
2900     }
2901 
2902     // For kernels we might need to initialize/finalize the IsSPMD state and
2903     // we need to register a simplification callback so that the Attributor
2904     // knows the constant arguments to __kmpc_target_init and
2905     // __kmpc_target_deinit might actually change.
2906 
2907     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2908         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2909             bool &UsedAssumedInformation) -> Optional<Value *> {
2910       // IRP represents the "use generic state machine" argument of an
2911       // __kmpc_target_init call. We will answer this one with the internal
2912       // state. As long as we are not in an invalid state, we will create a
2913       // custom state machine so the value should be a `i1 false`. If we are
2914       // in an invalid state, we won't change the value that is in the IR.
2915       if (!isValidState())
2916         return nullptr;
2917       // If we have disabled state machine rewrites, don't make a custom one.
2918       if (DisableOpenMPOptStateMachineRewrite)
2919         return nullptr;
2920       if (AA)
2921         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2922       UsedAssumedInformation = !isAtFixpoint();
2923       auto *FalseVal =
2924           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2925       return FalseVal;
2926     };
2927 
2928     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2929         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2930             bool &UsedAssumedInformation) -> Optional<Value *> {
2931       // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
2932       // __kmpc_target_deinit call. We answer it with the internal state of
2933       // the SPMDCompatibilityTracker, i.e., whether we still assume SPMD-mode
2934       // execution is possible.
2935       if (!SPMDCompatibilityTracker.isValidState())
2936         return nullptr;
2937       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2938         if (AA)
2939           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2940         UsedAssumedInformation = true;
2941       } else {
2942         UsedAssumedInformation = false;
2943       }
2944       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2945                                        SPMDCompatibilityTracker.isAssumed());
2946       return Val;
2947     };
2948 
2949     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2950         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2951             bool &UsedAssumedInformation) -> Optional<Value *> {
2952       // IRP represents the "RequiresFullRuntime" argument of an
2953       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2954       // one with the internal state of the SPMDCompatibilityTracker, so if
2955       // generic then true, if SPMD then false.
2956       if (!SPMDCompatibilityTracker.isValidState())
2957         return nullptr;
2958       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2959         if (AA)
2960           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2961         UsedAssumedInformation = true;
2962       } else {
2963         UsedAssumedInformation = false;
2964       }
2965       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2966                                        !SPMDCompatibilityTracker.isAssumed());
2967       return Val;
2968     };
2969 
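    // Argument positions of the IsSPMD, UseStateMachine, and
    // RequiresFullRuntime flags in the __kmpc_target_init and
    // __kmpc_target_deinit call sites.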
2970     constexpr const int InitIsSPMDArgNo = 1;
2971     constexpr const int DeinitIsSPMDArgNo = 1;
2972     constexpr const int InitUseStateMachineArgNo = 2;
2973     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2974     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2975     A.registerSimplificationCallback(
2976         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2977         StateMachineSimplifyCB);
2978     A.registerSimplificationCallback(
2979         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2980         IsSPMDModeSimplifyCB);
2981     A.registerSimplificationCallback(
2982         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2983         IsSPMDModeSimplifyCB);
2984     A.registerSimplificationCallback(
2985         IRPosition::callsite_argument(*KernelInitCB,
2986                                       InitRequiresFullRuntimeArgNo),
2987         IsGenericModeSimplifyCB);
2988     A.registerSimplificationCallback(
2989         IRPosition::callsite_argument(*KernelDeinitCB,
2990                                       DeinitRequiresFullRuntimeArgNo),
2991         IsGenericModeSimplifyCB);
2992 
2993     // Check if we know we are in SPMD-mode already.
2994     ConstantInt *IsSPMDArg =
2995         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
2996     if (IsSPMDArg && !IsSPMDArg->isZero())
2997       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
2998     // This is a generic region but SPMDization is disabled so stop tracking.
2999     else if (DisableOpenMPOptSPMDization)
3000       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3001   }
3002 
3003   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3004   /// finished now.
3005   ChangeStatus manifest(Attributor &A) override {
3006     // If we are not looking at a kernel with __kmpc_target_init and
3007     // __kmpc_target_deinit call we cannot actually manifest the information.
3008     if (!KernelInitCB || !KernelDeinitCB)
3009       return ChangeStatus::UNCHANGED;
3010 
3011     // Known SPMD-mode kernels need no manifest changes.
3012     if (SPMDCompatibilityTracker.isKnown())
3013       return ChangeStatus::UNCHANGED;
3014 
3015     // If we can, we change the execution mode to SPMD-mode; otherwise we build
3016     // a custom state machine.
3017     if (!mayContainParallelRegion() || !changeToSPMDMode(A))
3018       buildCustomStateMachine(A);
3019 
3020     return ChangeStatus::CHANGED;
3021   }
3022 
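  /// Try to convert the kernel from generic-mode to SPMD-mode execution. If
  /// the kernel is still assumed SPMD-compatible, the instructions recorded in
  /// the SPMDCompatibilityTracker are wrapped in main-thread-only guarded
  /// regions, the kernel's exec-mode global is updated, and the init/deinit
  /// calls are rewritten; returns true on success. Otherwise diagnostic
  /// remarks are emitted and false is returned.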
3023   bool changeToSPMDMode(Attributor &A) {
3024     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3025 
3026     if (!SPMDCompatibilityTracker.isAssumed()) {
3027       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3028         if (!NonCompatibleI)
3029           continue;
3030 
3031         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3032         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3033           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3034             continue;
3035 
3036         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3037           ORA << "Value has potential side effects preventing SPMD-mode "
3038                  "execution";
3039           if (isa<CallBase>(NonCompatibleI)) {
3040             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3041                    "the called function to override";
3042           }
3043           return ORA << ".";
3044         };
3045         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3046                                                  Remark);
3047 
3048         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3049                           << *NonCompatibleI << "\n");
3050       }
3051 
3052       return false;
3053     }
3054 
3055     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3056                                    Instruction *RegionEndI) {
3057       LoopInfo *LI = nullptr;
3058       DominatorTree *DT = nullptr;
3059       MemorySSAUpdater *MSU = nullptr;
3060       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3061 
3062       BasicBlock *ParentBB = RegionStartI->getParent();
3063       Function *Fn = ParentBB->getParent();
3064       Module &M = *Fn->getParent();
3065 
3066       // Create all the blocks and logic.
3067       // ParentBB:
3068       //    goto RegionCheckTidBB
3069       // RegionCheckTidBB:
3070       //    Tid = __kmpc_hardware_thread_id()
3071       //    if (Tid != 0)
3072       //        goto RegionBarrierBB
3073       // RegionStartBB:
3074       //    <execute instructions guarded>
3075       //    goto RegionEndBB
3076       // RegionEndBB:
3077       //    <store escaping values to shared mem>
3078       //    goto RegionBarrierBB
3079       //  RegionBarrierBB:
3080       //    __kmpc_simple_barrier_spmd()
3081       //    // second barrier is omitted if lacking escaping values.
3082       //    <load escaping values from shared mem>
3083       //    __kmpc_simple_barrier_spmd()
3084       //    goto RegionExitBB
3085       // RegionExitBB:
3086       //    <execute rest of instructions>
3087 
3088       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3089                                            DT, LI, MSU, "region.guarded.end");
3090       BasicBlock *RegionBarrierBB =
3091           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3092                      MSU, "region.barrier");
3093       BasicBlock *RegionExitBB =
3094           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3095                      DT, LI, MSU, "region.exit");
3096       BasicBlock *RegionStartBB =
3097           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3098 
3099       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3100              "Expected a different CFG");
3101 
3102       BasicBlock *RegionCheckTidBB = SplitBlock(
3103           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3104 
3105       // Register basic blocks with the Attributor.
3106       A.registerManifestAddedBasicBlock(*RegionEndBB);
3107       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3108       A.registerManifestAddedBasicBlock(*RegionExitBB);
3109       A.registerManifestAddedBasicBlock(*RegionStartBB);
3110       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3111 
3112       bool HasBroadcastValues = false;
3113       // Find escaping outputs from the guarded region to outside users and
3114       // broadcast their values to them.
3115       for (Instruction &I : *RegionStartBB) {
3116         SmallPtrSet<Instruction *, 4> OutsideUsers;
3117         for (User *Usr : I.users()) {
3118           Instruction &UsrI = *cast<Instruction>(Usr);
3119           if (UsrI.getParent() != RegionStartBB)
3120             OutsideUsers.insert(&UsrI);
3121         }
3122 
3123         if (OutsideUsers.empty())
3124           continue;
3125 
3126         HasBroadcastValues = true;
3127 
3128         // Emit a global variable in shared memory to store the broadcasted
3129         // value.
3130         auto *SharedMem = new GlobalVariable(
3131             M, I.getType(), /* IsConstant */ false,
3132             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3133             I.getName() + ".guarded.output.alloc", nullptr,
3134             GlobalValue::NotThreadLocal,
3135             static_cast<unsigned>(AddressSpace::Shared));
3136 
3137         // Emit a store instruction to update the value.
3138         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3139 
3140         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3141                                        I.getName() + ".guarded.output.load",
3142                                        RegionBarrierBB->getTerminator());
3143 
3144         // Emit a load instruction and replace uses of the output value.
3145         for (Instruction *UsrI : OutsideUsers) {
3146           assert(UsrI->getParent() == RegionExitBB &&
3147                  "Expected escaping users in exit region");
3148           UsrI->replaceUsesOfWith(&I, LoadI);
3149         }
3150       }
3151 
3152       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3153 
3154       // Go to tid check BB in ParentBB.
3155       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3156       ParentBB->getTerminator()->eraseFromParent();
3157       OpenMPIRBuilder::LocationDescription Loc(
3158           InsertPointTy(ParentBB, ParentBB->end()), DL);
3159       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3160       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3161       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3162       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3163 
3164       // Add check for Tid in RegionCheckTidBB
3165       RegionCheckTidBB->getTerminator()->eraseFromParent();
3166       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3167           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3168       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3169       FunctionCallee HardwareTidFn =
3170           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3171               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3172       Value *Tid =
3173           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3174       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3175       OMPInfoCache.OMPBuilder.Builder
3176           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3177           ->setDebugLoc(DL);
3178 
3179       // First barrier for synchronization, ensures main thread has updated
3180       // values.
3181       FunctionCallee BarrierFn =
3182           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3183               M, OMPRTL___kmpc_barrier_simple_spmd);
3184       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3185           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3186       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3187           ->setDebugLoc(DL);
3188 
3189       // Second barrier ensures workers have read broadcast values.
3190       if (HasBroadcastValues)
3191         CallInst::Create(BarrierFn, {Ident, Tid}, "",
3192                          RegionBarrierBB->getTerminator())
3193             ->setDebugLoc(DL);
3194     };
3195 
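    // Before forming guarded regions, cluster the instructions that need
    // guarding: in each affected block, a guarded, user-free, effectful
    // instruction is moved directly before the next such instruction, but only
    // across instructions without memory effects (and across
    // __kmpc_alloc_shared calls). This packs guarded instructions into fewer,
    // larger regions below.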
3196     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3197     SmallPtrSet<BasicBlock *, 8> Visited;
3198     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3199       BasicBlock *BB = GuardedI->getParent();
3200       if (!Visited.insert(BB).second)
3201         continue;
3202 
3203       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3204       Instruction *LastEffect = nullptr;
3205       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3206       while (++IP != IPEnd) {
3207         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3208           continue;
3209         Instruction *I = &*IP;
3210         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3211           continue;
3212         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3213           LastEffect = nullptr;
3214           continue;
3215         }
3216         if (LastEffect)
3217           Reorders.push_back({I, LastEffect});
3218         LastEffect = &*IP;
3219       }
3220       for (auto &Reorder : Reorders)
3221         Reorder.first->moveBefore(Reorder.second);
3222     }
3223 
3224     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3225 
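    // Collect the guarded regions: within each block, every maximal run of
    // consecutive instructions that the SPMDCompatibilityTracker marks as
    // needing guarding becomes one (start, end) pair.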
3226     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3227       BasicBlock *BB = GuardedI->getParent();
3228       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3229           IRPosition::function(*GuardedI->getFunction()), nullptr,
3230           DepClassTy::NONE);
3231       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3232       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3233       // Continue if instruction is already guarded.
3234       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3235         continue;
3236 
3237       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3238       for (Instruction &I : *BB) {
3239         // If instruction I needs to be guarded, update the guarded region
3240         // bounds.
3241         if (SPMDCompatibilityTracker.contains(&I)) {
3242           CalleeAAFunction.getGuardedInstructions().insert(&I);
3243           if (GuardedRegionStart)
3244             GuardedRegionEnd = &I;
3245           else
3246             GuardedRegionStart = GuardedRegionEnd = &I;
3247 
3248           continue;
3249         }
3250 
3251         // Instruction I does not need guarding; store any region found so far
3252         // and reset the region bounds.
3253         if (GuardedRegionStart) {
3254           GuardedRegions.push_back(
3255               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3256           GuardedRegionStart = nullptr;
3257           GuardedRegionEnd = nullptr;
3258         }
3259       }
3260     }
3261 
3262     for (auto &GR : GuardedRegions)
3263       CreateGuardedRegion(GR.first, GR.second);
3264 
3265     // Adjust the global exec mode flag that tells the runtime what mode this
3266     // kernel is executed in.
3267     Function *Kernel = getAnchorScope();
3268     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3269         (Kernel->getName() + "_exec_mode").str());
3270     assert(ExecMode && "Kernel without exec mode?");
3271     assert(ExecMode->getInitializer() &&
3272            ExecMode->getInitializer()->isOneValue() &&
3273            "Initially non-SPMD kernel has SPMD exec mode!");
3274 
3275     // Set the global exec mode flag to indicate SPMD-Generic mode.
3276     constexpr int SPMDGeneric = 2;
3277     if (!ExecMode->getInitializer()->isZeroValue())
3278       ExecMode->setInitializer(
3279           ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
3280 
3281     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3282     const int InitIsSPMDArgNo = 1;
3283     const int DeinitIsSPMDArgNo = 1;
3284     const int InitUseStateMachineArgNo = 2;
3285     const int InitRequiresFullRuntimeArgNo = 3;
3286     const int DeinitRequiresFullRuntimeArgNo = 2;
3287 
3288     auto &Ctx = getAnchorValue().getContext();
3289     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3290                              *ConstantInt::getBool(Ctx, 1));
3291     A.changeUseAfterManifest(
3292         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3293         *ConstantInt::getBool(Ctx, 0));
3294     A.changeUseAfterManifest(
3295         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3296         *ConstantInt::getBool(Ctx, 1));
3297     A.changeUseAfterManifest(
3298         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3299         *ConstantInt::getBool(Ctx, 0));
3300     A.changeUseAfterManifest(
3301         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3302         *ConstantInt::getBool(Ctx, 0));
3303 
3304     ++NumOpenMPTargetRegionKernelsSPMD;
3305 
3306     auto Remark = [&](OptimizationRemark OR) {
3307       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3308     };
3309     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3310     return true;
3311   }
3312 
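  /// Rewrite a generic-mode kernel to use a specialized worker state machine
  /// instead of the runtime's generic one: workers wait at a barrier, pick up
  /// the work function published by the main thread, and dispatch to the
  /// statically known parallel regions via an if-cascade, with an indirect
  /// call fallback for unknown parallel regions.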
3313   ChangeStatus buildCustomStateMachine(Attributor &A) {
3314     // If we have disabled state machine rewrites, don't make a custom one
3315     if (DisableOpenMPOptStateMachineRewrite)
3316       return indicatePessimisticFixpoint();
3317 
3318     assert(ReachedKnownParallelRegions.isValidState() &&
3319            "Custom state machine with invalid parallel region states?");
3320 
3321     const int InitIsSPMDArgNo = 1;
3322     const int InitUseStateMachineArgNo = 2;
3323 
3324     // Check if the current configuration is non-SPMD and uses the generic state
3325     // machine. If we already have SPMD mode or a custom state machine, we do not
3326     // need to go any further. If either argument is anything but a constant,
3327     // something is weird and we give up.
3328     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3329         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3330     ConstantInt *IsSPMD =
3331         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3332 
3333     // If we are stuck with generic mode, try to create a custom device (=GPU)
3334     // state machine which is specialized for the parallel regions that are
3335     // reachable by the kernel.
3336     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3337         !IsSPMD->isZero())
3338       return ChangeStatus::UNCHANGED;
3339 
3340     // If not SPMD mode, indicate we use a custom state machine now.
3341     auto &Ctx = getAnchorValue().getContext();
3342     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3343     A.changeUseAfterManifest(
3344         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3345 
3346     // If we don't actually need a state machine we are done here. This can
3347     // happen if there simply are no parallel regions. In the resulting kernel
3348     // all worker threads will simply exit right away, leaving the main thread
3349     // to do the work alone.
3350     if (!mayContainParallelRegion()) {
3351       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3352 
3353       auto Remark = [&](OptimizationRemark OR) {
3354         return OR << "Removing unused state machine from generic-mode kernel.";
3355       };
3356       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3357 
3358       return ChangeStatus::CHANGED;
3359     }
3360 
3361     // Keep track in the statistics of our new shiny custom state machine.
3362     if (ReachedUnknownParallelRegions.empty()) {
3363       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3364 
3365       auto Remark = [&](OptimizationRemark OR) {
3366         return OR << "Rewriting generic-mode kernel with a customized state "
3367                      "machine.";
3368       };
3369       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3370     } else {
3371       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3372 
3373       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3374         return OR << "Generic-mode kernel is executed with a customized state "
3375                      "machine that requires a fallback.";
3376       };
3377       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3378 
3379       // Tell the user why we ended up with a fallback.
3380       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3381         if (!UnknownParallelRegionCB)
3382           continue;
3383         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3384           return ORA << "Call may contain unknown parallel regions. Use "
3385                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3386                         "override.";
3387         };
3388         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3389                                                  "OMP133", Remark);
3390       }
3391     }
3392 
3393     // Create all the blocks:
3394     //
3395     //                       InitCB = __kmpc_target_init(...)
3396     //                       bool IsWorker = InitCB >= 0;
3397     //                       if (IsWorker) {
3398     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3399     //                         void *WorkFn;
3400     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3401     //                         if (!WorkFn) return;
3402     // SMIsActiveCheckBB:       if (Active) {
3403     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3404     //                              ParFn0(...);
3405     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3406     //                              ParFn1(...);
3407     //                            ...
3408     // SMIfCascadeCurrentBB:      else
3409     //                              ((WorkFnTy*)WorkFn)(...);
3410     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3411     //                          }
3412     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3413     //                          goto SMBeginBB;
3414     //                       }
3415     // UserCodeEntryBB:      // user code
3416     //                       __kmpc_target_deinit(...)
3417     //
3418     Function *Kernel = getAssociatedFunction();
3419     assert(Kernel && "Expected an associated function!");
3420 
3421     BasicBlock *InitBB = KernelInitCB->getParent();
3422     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3423         KernelInitCB->getNextNode(), "thread.user_code.check");
3424     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3425         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3426     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3427         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3428     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3429         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3430     BasicBlock *StateMachineIfCascadeCurrentBB =
3431         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3432                            Kernel, UserCodeEntryBB);
3433     BasicBlock *StateMachineEndParallelBB =
3434         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3435                            Kernel, UserCodeEntryBB);
3436     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3437         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3438     A.registerManifestAddedBasicBlock(*InitBB);
3439     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3440     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3441     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3442     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3443     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3444     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3445     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3446 
3447     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3448     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3449 
3450     InitBB->getTerminator()->eraseFromParent();
3451     Instruction *IsWorker =
3452         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3453                          ConstantInt::get(KernelInitCB->getType(), -1),
3454                          "thread.is_worker", InitBB);
3455     IsWorker->setDebugLoc(DLoc);
3456     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3457 
3458     // Create local storage for the work function pointer.
3459     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3460     AllocaInst *WorkFnAI = new AllocaInst(VoidPtrTy, 0, "worker.work_fn.addr",
3461                                           &Kernel->getEntryBlock().front());
3462     WorkFnAI->setDebugLoc(DLoc);
3463 
3464     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3465     OMPInfoCache.OMPBuilder.updateToLocation(
3466         OpenMPIRBuilder::LocationDescription(
3467             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3468                                      StateMachineBeginBB->end()),
3469             DLoc));
3470 
3471     Value *Ident = KernelInitCB->getArgOperand(0);
3472     Value *GTid = KernelInitCB;
3473 
3474     Module &M = *Kernel->getParent();
3475     FunctionCallee BarrierFn =
3476         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3477             M, OMPRTL___kmpc_barrier_simple_spmd);
3478     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3479         ->setDebugLoc(DLoc);
3480 
3481     FunctionCallee KernelParallelFn =
3482         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3483             M, OMPRTL___kmpc_kernel_parallel);
3484     Instruction *IsActiveWorker = CallInst::Create(
3485         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3486     IsActiveWorker->setDebugLoc(DLoc);
3487     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3488                                        StateMachineBeginBB);
3489     WorkFn->setDebugLoc(DLoc);
3490 
3491     FunctionType *ParallelRegionFnTy = FunctionType::get(
3492         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3493         false);
3494     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3495         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3496         StateMachineBeginBB);
3497 
3498     Instruction *IsDone =
3499         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3500                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3501                          StateMachineBeginBB);
3502     IsDone->setDebugLoc(DLoc);
3503     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3504                        IsDone, StateMachineBeginBB)
3505         ->setDebugLoc(DLoc);
3506 
3507     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3508                        StateMachineDoneBarrierBB, IsActiveWorker,
3509                        StateMachineIsActiveCheckBB)
3510         ->setDebugLoc(DLoc);
3511 
3512     Value *ZeroArg =
3513         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3514 
3515     // Now that we have most of the CFG skeleton it is time for the if-cascade
3516     // that checks the function pointer we got from the runtime against the
3517     // parallel regions we expect, if there are any.
3518     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3519       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3520       BasicBlock *PRExecuteBB = BasicBlock::Create(
3521           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3522           StateMachineEndParallelBB);
3523       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3524           ->setDebugLoc(DLoc);
3525       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3526           ->setDebugLoc(DLoc);
3527 
3528       BasicBlock *PRNextBB =
3529           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3530                              Kernel, StateMachineEndParallelBB);
3531 
3532       // Check if we need to compare the pointer at all or if we can just
3533       // call the parallel region function.
3534       Value *IsPR;
3535       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3536         Instruction *CmpI = ICmpInst::Create(
3537             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3538             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3539         CmpI->setDebugLoc(DLoc);
3540         IsPR = CmpI;
3541       } else {
3542         IsPR = ConstantInt::getTrue(Ctx);
3543       }
3544 
3545       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3546                          StateMachineIfCascadeCurrentBB)
3547           ->setDebugLoc(DLoc);
3548       StateMachineIfCascadeCurrentBB = PRNextBB;
3549     }
3550 
3551     // At the end of the if-cascade we place the indirect function pointer call
3552     // in case we might need it, that is if there can be parallel regions we
3553     // have not handled in the if-cascade above.
3554     if (!ReachedUnknownParallelRegions.empty()) {
3555       StateMachineIfCascadeCurrentBB->setName(
3556           "worker_state_machine.parallel_region.fallback.execute");
3557       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3558                        StateMachineIfCascadeCurrentBB)
3559           ->setDebugLoc(DLoc);
3560     }
3561     BranchInst::Create(StateMachineEndParallelBB,
3562                        StateMachineIfCascadeCurrentBB)
3563         ->setDebugLoc(DLoc);
3564 
3565     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3566                          M, OMPRTL___kmpc_kernel_end_parallel),
3567                      {}, "", StateMachineEndParallelBB)
3568         ->setDebugLoc(DLoc);
3569     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3570         ->setDebugLoc(DLoc);
3571 
3572     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3573         ->setDebugLoc(DLoc);
3574     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3575         ->setDebugLoc(DLoc);
3576 
3577     return ChangeStatus::CHANGED;
3578   }
3579 
3580   /// Fixpoint iteration update function. Will be called every time a dependence
3581   /// changed its state (and in the beginning).
3582   ChangeStatus updateImpl(Attributor &A) override {
3583     KernelInfoState StateBefore = getState();
3584 
3585     // Callback to check a read/write instruction.
3586     auto CheckRWInst = [&](Instruction &I) {
3587       // We handle calls later.
3588       if (isa<CallBase>(I))
3589         return true;
3590       // We only care about write effects.
3591       if (!I.mayWriteToMemory())
3592         return true;
3593       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3594         SmallVector<const Value *> Objects;
3595         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3596         if (llvm::all_of(Objects,
3597                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3598           return true;
3599         // Check for AAHeapToStack moved objects which must not be guarded.
3600         auto &HS = A.getAAFor<AAHeapToStack>(
3601             *this, IRPosition::function(*I.getFunction()),
3602             DepClassTy::REQUIRED);
3603         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3604               auto *CB = dyn_cast<CallBase>(Obj);
3605               if (!CB)
3606                 return false;
3607               return HS.isAssumedHeapToStack(*CB);
3608             })) {
3609           return true;
3610         }
3611       }
3612 
3613       // Insert instruction that needs guarding.
3614       SPMDCompatibilityTracker.insert(&I);
3615       return true;
3616     };
3617 
3618     bool UsedAssumedInformationInCheckRWInst = false;
3619     if (!SPMDCompatibilityTracker.isAtFixpoint())
3620       if (!A.checkForAllReadWriteInstructions(
3621               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3622         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3623 
3624     if (!IsKernelEntry) {
3625       updateReachingKernelEntries(A);
3626       updateParallelLevels(A);
3627 
3628       if (!ParallelLevels.isValidState())
3629         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3630     }
3631 
3632     // Callback to check a call instruction.
3633     bool AllSPMDStatesWereFixed = true;
3634     auto CheckCallInst = [&](Instruction &I) {
3635       auto &CB = cast<CallBase>(I);
3636       auto &CBAA = A.getAAFor<AAKernelInfo>(
3637           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3638       getState() ^= CBAA.getState();
3639       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3640       return true;
3641     };
3642 
3643     bool UsedAssumedInformationInCheckCallInst = false;
3644     if (!A.checkForAllCallLikeInstructions(
3645             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3646       return indicatePessimisticFixpoint();
3647 
3648     // If we haven't used any assumed information for the SPMD state we can fix
3649     // it.
3650     if (!UsedAssumedInformationInCheckRWInst &&
3651         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3652       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3653 
3654     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3655                                      : ChangeStatus::CHANGED;
3656   }
3657 
3658 private:
3659   /// Update info regarding reaching kernels.
3660   void updateReachingKernelEntries(Attributor &A) {
3661     auto PredCallSite = [&](AbstractCallSite ACS) {
3662       Function *Caller = ACS.getInstruction()->getFunction();
3663 
3664       assert(Caller && "Caller is nullptr");
3665 
3666       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3667           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3668       if (CAA.ReachingKernelEntries.isValidState()) {
3669         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3670         return true;
3671       }
3672 
3673       // We lost track of the caller of the associated function, any kernel
3674       // could reach now.
3675       ReachingKernelEntries.indicatePessimisticFixpoint();
3676 
3677       return true;
3678     };
3679 
3680     bool AllCallSitesKnown;
3681     if (!A.checkForAllCallSites(PredCallSite, *this,
3682                                 true /* RequireAllCallSites */,
3683                                 AllCallSitesKnown))
3684       ReachingKernelEntries.indicatePessimisticFixpoint();
3685   }
3686 
3687   /// Update info regarding parallel levels.
3688   void updateParallelLevels(Attributor &A) {
3689     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3690     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3691         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3692 
3693     auto PredCallSite = [&](AbstractCallSite ACS) {
3694       Function *Caller = ACS.getInstruction()->getFunction();
3695 
3696       assert(Caller && "Caller is nullptr");
3697 
3698       auto &CAA =
3699           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3700       if (CAA.ParallelLevels.isValidState()) {
3701         // Any function that is called by `__kmpc_parallel_51` will not be
3702         // folded because the parallel level inside that function is updated at
3703         // run time. Getting this right would tie the analysis to the runtime
3704         // implementation, and any future change to that implementation could
3705         // silently break it. As a consequence, we are just conservative here.
3706         if (Caller == Parallel51RFI.Declaration) {
3707           ParallelLevels.indicatePessimisticFixpoint();
3708           return true;
3709         }
3710 
3711         ParallelLevels ^= CAA.ParallelLevels;
3712 
3713         return true;
3714       }
3715 
3716       // We lost track of the caller of the associated function, any kernel
3717       // could reach now.
3718       ParallelLevels.indicatePessimisticFixpoint();
3719 
3720       return true;
3721     };
3722 
3723     bool AllCallSitesKnown = true;
3724     if (!A.checkForAllCallSites(PredCallSite, *this,
3725                                 true /* RequireAllCallSites */,
3726                                 AllCallSitesKnown))
3727       ParallelLevels.indicatePessimisticFixpoint();
3728   }
3729 };
3730 
3731 /// The call site kernel info abstract attribute, basically, what can we say
3732 /// about a call site with regards to the KernelInfoState. For now this simply
3733 /// forwards the information from the callee.
3734 struct AAKernelInfoCallSite : AAKernelInfo {
3735   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3736       : AAKernelInfo(IRP, A) {}
3737 
3738   /// See AbstractAttribute::initialize(...).
3739   void initialize(Attributor &A) override {
3740     AAKernelInfo::initialize(A);
3741 
3742     CallBase &CB = cast<CallBase>(getAssociatedValue());
3743     Function *Callee = getAssociatedFunction();
3744 
3745     // Helper to lookup an assumption string.
3746     auto HasAssumption = [](CallBase &CB, StringRef AssumptionStr) {
3747       return hasAssumption(CB, AssumptionStr);
3748     };
3749 
3750     // Check for SPMD-mode assumptions.
3751     if (HasAssumption(CB, "ompx_spmd_amenable")) {
3752       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3753       indicateOptimisticFixpoint();
3754     }
3755 
3756     // First weed out calls we do not care about, that is readonly/readnone
3757     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3758     // parallel region or anything else we are looking for.
3759     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3760       indicateOptimisticFixpoint();
3761       return;
3762     }
3763 
3764     // Next we check if we know the callee. If it is a known OpenMP function
3765     // we will handle them explicitly in the switch below. If it is not, we
3766     // will use an AAKernelInfo object on the callee to gather information and
3767     // merge that into the current state. The latter happens in the updateImpl.
3768     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3769     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3770     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3771       // Unknown callees or declarations are not analyzable, we give up.
3772       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3773 
3774         // Unknown callees might contain parallel regions, except if they have
3775         // an appropriate assumption attached.
3776         if (!(HasAssumption(CB, "omp_no_openmp") ||
3777               HasAssumption(CB, "omp_no_parallelism")))
3778           ReachedUnknownParallelRegions.insert(&CB);
3779 
3780         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3781         // idea we can run something unknown in SPMD-mode.
3782         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3783           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3784           SPMDCompatibilityTracker.insert(&CB);
3785         }
3786 
3787         // We have updated the state for this unknown call properly; there won't
3788         // be any change, so we indicate a fixpoint.
3789         indicateOptimisticFixpoint();
3790       }
3791       // If the callee is known and can be used in IPO, we will update the state
3792       // based on the callee state in updateImpl.
3793       return;
3794     }
3795 
3796     const unsigned int WrapperFunctionArgNo = 6;
3797     RuntimeFunction RF = It->getSecond();
3798     switch (RF) {
3799     // All the functions we know are compatible with SPMD mode.
3800     case OMPRTL___kmpc_is_spmd_exec_mode:
3801     case OMPRTL___kmpc_for_static_fini:
3802     case OMPRTL___kmpc_global_thread_num:
3803     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3804     case OMPRTL___kmpc_get_hardware_num_blocks:
3805     case OMPRTL___kmpc_single:
3806     case OMPRTL___kmpc_end_single:
3807     case OMPRTL___kmpc_master:
3808     case OMPRTL___kmpc_end_master:
3809     case OMPRTL___kmpc_barrier:
3810       break;
3811     case OMPRTL___kmpc_for_static_init_4:
3812     case OMPRTL___kmpc_for_static_init_4u:
3813     case OMPRTL___kmpc_for_static_init_8:
3814     case OMPRTL___kmpc_for_static_init_8u: {
3815       // Check the schedule and allow static schedule in SPMD mode.
3816       unsigned ScheduleArgOpNo = 2;
3817       auto *ScheduleTypeCI =
3818           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3819       unsigned ScheduleTypeVal =
3820           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3821       switch (OMPScheduleType(ScheduleTypeVal)) {
3822       case OMPScheduleType::Static:
3823       case OMPScheduleType::StaticChunked:
3824       case OMPScheduleType::Distribute:
3825       case OMPScheduleType::DistributeChunked:
3826         break;
3827       default:
3828         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3829         SPMDCompatibilityTracker.insert(&CB);
3830         break;
3831       };
3832     } break;
3833     case OMPRTL___kmpc_target_init:
3834       KernelInitCB = &CB;
3835       break;
3836     case OMPRTL___kmpc_target_deinit:
3837       KernelDeinitCB = &CB;
3838       break;
3839     case OMPRTL___kmpc_parallel_51:
3840       if (auto *ParallelRegion = dyn_cast<Function>(
3841               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3842         ReachedKnownParallelRegions.insert(ParallelRegion);
3843         break;
3844       }
3845       // The condition above should usually get the parallel region function
3846       // pointer and record it. On the off chance it doesn't, we assume the
3847       // worst.
3848       ReachedUnknownParallelRegions.insert(&CB);
3849       break;
3850     case OMPRTL___kmpc_omp_task:
3851       // We do not look into tasks right now, just give up.
3852       SPMDCompatibilityTracker.insert(&CB);
3853       ReachedUnknownParallelRegions.insert(&CB);
3854       indicatePessimisticFixpoint();
3855       return;
3856     case OMPRTL___kmpc_alloc_shared:
3857     case OMPRTL___kmpc_free_shared:
3858       // Return without setting a fixpoint, to be resolved in updateImpl.
3859       return;
3860     default:
3861       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3862       // generally.
3863       SPMDCompatibilityTracker.insert(&CB);
3864       indicatePessimisticFixpoint();
3865       return;
3866     }
3867     // All other OpenMP runtime calls will not reach parallel regions so they
3868     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3869     // have now modeled all effects and there is no need for any update.
3870     indicateOptimisticFixpoint();
3871   }
3872 
3873   ChangeStatus updateImpl(Attributor &A) override {
3874     // TODO: Once we have call site specific value information we can provide
3875     //       call site specific liveness information and then it makes
3876     //       sense to specialize attributes for call sites arguments instead of
3877     //       redirecting requests to the callee argument.
3878     Function *F = getAssociatedFunction();
3879 
3880     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3881     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3882 
3883     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3884     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3885       const IRPosition &FnPos = IRPosition::function(*F);
3886       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3887       if (getState() == FnAA.getState())
3888         return ChangeStatus::UNCHANGED;
3889       getState() = FnAA.getState();
3890       return ChangeStatus::CHANGED;
3891     }
3892 
3893     // F is a runtime function that allocates or frees memory, check
3894     // AAHeapToStack and AAHeapToShared.
3895     KernelInfoState StateBefore = getState();
3896     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3897             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3898            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3899 
3900     CallBase &CB = cast<CallBase>(getAssociatedValue());
3901 
3902     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3903         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3904     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3905         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3906 
3907     RuntimeFunction RF = It->getSecond();
3908 
3909     switch (RF) {
3910     // If neither HeapToStack nor HeapToShared assume the call is removed,
3911     // assume SPMD incompatibility.
3912     case OMPRTL___kmpc_alloc_shared:
3913       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3914           !HeapToSharedAA.isAssumedHeapToShared(CB))
3915         SPMDCompatibilityTracker.insert(&CB);
3916       break;
3917     case OMPRTL___kmpc_free_shared:
3918       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3919           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3920         SPMDCompatibilityTracker.insert(&CB);
3921       break;
3922     default:
3923       SPMDCompatibilityTracker.insert(&CB);
3924     }
3925 
3926     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3927                                      : ChangeStatus::CHANGED;
3928   }
3929 };
3930 
3931 struct AAFoldRuntimeCall
3932     : public StateWrapper<BooleanState, AbstractAttribute> {
3933   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3934 
3935   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3936 
3937   /// Statistics are tracked as part of manifest for now.
3938   void trackStatistics() const override {}
3939 
3940   /// Create an abstract attribute view for the position \p IRP.
3941   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3942                                               Attributor &A);
3943 
3944   /// See AbstractAttribute::getName()
3945   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3946 
3947   /// See AbstractAttribute::getIdAddr()
3948   const char *getIdAddr() const override { return &ID; }
3949 
3950   /// This function should return true if the type of the \p AA is
3951   /// AAFoldRuntimeCall
3952   static bool classof(const AbstractAttribute *AA) {
3953     return (AA->getIdAddr() == &ID);
3954   }
3955 
3956   static const char ID;
3957 };
3958 
3959 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3960   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3961       : AAFoldRuntimeCall(IRP, A) {}
3962 
3963   /// See AbstractAttribute::getAsStr()
3964   const std::string getAsStr() const override {
3965     if (!isValidState())
3966       return "<invalid>";
3967 
3968     std::string Str("simplified value: ");
3969 
3970     if (!SimplifiedValue.hasValue())
3971       return Str + std::string("none");
3972 
3973     if (!SimplifiedValue.getValue())
3974       return Str + std::string("nullptr");
3975 
3976     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
3977       return Str + std::to_string(CI->getSExtValue());
3978 
3979     return Str + std::string("unknown");
3980   }
3981 
3982   void initialize(Attributor &A) override {
3983     if (DisableOpenMPOptFolding)
3984       indicatePessimisticFixpoint();
3985 
3986     Function *Callee = getAssociatedFunction();
3987 
3988     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3989     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3990     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
3991            "Expected a known OpenMP runtime function");
3992 
3993     RFKind = It->getSecond();
3994 
3995     CallBase &CB = cast<CallBase>(getAssociatedValue());
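    // When the Attributor asks to simplify the returned value of this call
    // site, report the currently assumed SimplifiedValue. While not at a
    // fixpoint the information is only assumed, so record an optional
    // dependence such that the querying AA is rerun if our state changes.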
3996     A.registerSimplificationCallback(
3997         IRPosition::callsite_returned(CB),
3998         [&](const IRPosition &IRP, const AbstractAttribute *AA,
3999             bool &UsedAssumedInformation) -> Optional<Value *> {
4000           assert((isValidState() || (SimplifiedValue.hasValue() &&
4001                                      SimplifiedValue.getValue() == nullptr)) &&
4002                  "Unexpected invalid state!");
4003 
4004           if (!isAtFixpoint()) {
4005             UsedAssumedInformation = true;
4006             if (AA)
4007               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4008           }
4009           return SimplifiedValue;
4010         });
4011   }
4012 
4013   ChangeStatus updateImpl(Attributor &A) override {
4014     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4015     switch (RFKind) {
4016     case OMPRTL___kmpc_is_spmd_exec_mode:
4017       Changed |= foldIsSPMDExecMode(A);
4018       break;
4019     case OMPRTL___kmpc_is_generic_main_thread_id:
4020       Changed |= foldIsGenericMainThread(A);
4021       break;
4022     case OMPRTL___kmpc_parallel_level:
4023       Changed |= foldParallelLevel(A);
4024       break;
4025     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
4027       break;
4028     case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
4030       break;
4031     default:
4032       llvm_unreachable("Unhandled OpenMP runtime function!");
4033     }
4034 
4035     return Changed;
4036   }
4037 
4038   ChangeStatus manifest(Attributor &A) override {
4039     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4040 
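    // If a simplified value is known or assumed, replace all uses of the
    // call with it and delete the call afterwards.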
4041     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4042       Instruction &I = *getCtxI();
4043       A.changeValueAfterManifest(I, **SimplifiedValue);
4044       A.deleteAfterManifest(I);
4045 
4046       CallBase *CB = dyn_cast<CallBase>(&I);
4047       auto Remark = [&](OptimizationRemark OR) {
4048         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4049           return OR << "Replacing OpenMP runtime call "
4050                     << CB->getCalledFunction()->getName() << " with "
4051                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4052         else
4053           return OR << "Replacing OpenMP runtime call "
4054                     << CB->getCalledFunction()->getName() << ".";
4055       };
4056 
4057       if (CB && EnableVerboseRemarks)
4058         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4059 
4060       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4061                         << **SimplifiedValue << "\n");
4062 
4063       Changed = ChangeStatus::CHANGED;
4064     }
4065 
4066     return Changed;
4067   }
4068 
4069   ChangeStatus indicatePessimisticFixpoint() override {
4070     SimplifiedValue = nullptr;
4071     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4072   }
4073 
4074 private:
4075   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4076   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4077     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4078 
4079     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4080     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4081     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4082         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4083 
4084     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4085       return indicatePessimisticFixpoint();
4086 
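    // Count how many of the reaching kernels are known or assumed to run in
    // SPMD mode and how many are not.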
4087     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4088       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4089                                           DepClassTy::REQUIRED);
4090 
4091       if (!AA.isValidState()) {
4092         SimplifiedValue = nullptr;
4093         return indicatePessimisticFixpoint();
4094       }
4095 
4096       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4097         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4098           ++KnownSPMDCount;
4099         else
4100           ++AssumedSPMDCount;
4101       } else {
4102         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4103           ++KnownNonSPMDCount;
4104         else
4105           ++AssumedNonSPMDCount;
4106       }
4107     }
4108 
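    // If both SPMD and non-SPMD kernels can reach this call, the result is
    // not uniform and the call cannot be folded.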
4109     if ((AssumedSPMDCount + KnownSPMDCount) &&
4110         (AssumedNonSPMDCount + KnownNonSPMDCount))
4111       return indicatePessimisticFixpoint();
4112 
4113     auto &Ctx = getAnchorValue().getContext();
4114     if (KnownSPMDCount || AssumedSPMDCount) {
4115       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4116              "Expected only SPMD kernels!");
      // All reaching kernels are in SPMD mode, so fold every call to
      // __kmpc_is_spmd_exec_mode to 1.
4119       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4120     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4121       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4122              "Expected only non-SPMD kernels!");
      // All reaching kernels are in non-SPMD mode, so fold every call to
      // __kmpc_is_spmd_exec_mode to 0.
4125       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4126     } else {
      // The set of reaching kernels is empty, so we cannot tell whether the
      // associated call site can be folded. At this point SimplifiedValue
      // must still be none.
4130       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4131     }
4132 
4133     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4134                                                     : ChangeStatus::CHANGED;
4135   }
4136 
4137   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4138   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4139     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4140 
4141     CallBase &CB = cast<CallBase>(getAssociatedValue());
4142     Function *F = CB.getFunction();
4143     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4144         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4145 
4146     if (!ExecutionDomainAA.isValidState())
4147       return indicatePessimisticFixpoint();
4148 
4149     auto &Ctx = getAnchorValue().getContext();
4150     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4151       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4152     else
4153       return indicatePessimisticFixpoint();
4154 
4155     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4156                                                     : ChangeStatus::CHANGED;
4157   }
4158 
4159   /// Fold __kmpc_parallel_level into a constant if possible.
4160   ChangeStatus foldParallelLevel(Attributor &A) {
4161     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4162 
4163     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4164         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4165 
4166     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4167       return indicatePessimisticFixpoint();
4168 
4169     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4170       return indicatePessimisticFixpoint();
4171 
4172     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4173       assert(!SimplifiedValue.hasValue() &&
4174              "SimplifiedValue should keep none at this point");
4175       return ChangeStatus::UNCHANGED;
4176     }
4177 
4178     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4179     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4180     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4181       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4182                                           DepClassTy::REQUIRED);
4183       if (!AA.SPMDCompatibilityTracker.isValidState())
4184         return indicatePessimisticFixpoint();
4185 
4186       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4187         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4188           ++KnownSPMDCount;
4189         else
4190           ++AssumedSPMDCount;
4191       } else {
4192         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4193           ++KnownNonSPMDCount;
4194         else
4195           ++AssumedNonSPMDCount;
4196       }
4197     }
4198 
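    // A mix of SPMD and non-SPMD kernels implies different parallel levels
    // at this call site, so it cannot be folded.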
4199     if ((AssumedSPMDCount + KnownSPMDCount) &&
4200         (AssumedNonSPMDCount + KnownNonSPMDCount))
4201       return indicatePessimisticFixpoint();
4202 
4203     auto &Ctx = getAnchorValue().getContext();
4204     // If the caller can only be reached by SPMD kernel entries, the parallel
4205     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4206     // kernel entries, it is 0.
4207     if (AssumedSPMDCount || KnownSPMDCount) {
4208       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4209              "Expected only SPMD kernels!");
4210       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4211     } else {
4212       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4213              "Expected only non-SPMD kernels!");
4214       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4215     }
4216     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4217                                                     : ChangeStatus::CHANGED;
4218   }
4219 
4220   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
    // Specialize only if all reaching kernels agree on the attribute value.
4222     int32_t CurrentAttrValue = -1;
4223     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4224 
4225     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4226         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4227 
4228     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4229       return indicatePessimisticFixpoint();
4230 
4231     // Iterate over the kernels that reach this function
4232     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4233       int32_t NextAttrVal = -1;
4234       if (K->hasFnAttribute(Attr))
4235         NextAttrVal =
4236             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4237 
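      // Bail if a reaching kernel lacks the attribute or the kernels disagree
      // on its value.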
4238       if (NextAttrVal == -1 ||
4239           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4240         return indicatePessimisticFixpoint();
4241       CurrentAttrValue = NextAttrVal;
4242     }
4243 
4244     if (CurrentAttrValue != -1) {
4245       auto &Ctx = getAnchorValue().getContext();
4246       SimplifiedValue =
4247           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4248     }
4249     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4250                                                     : ChangeStatus::CHANGED;
4251   }
4252 
4253   /// An optional value the associated value is assumed to fold to. That is, we
4254   /// assume the associated value (which is a call) can be replaced by this
4255   /// simplified value.
4256   Optional<Value *> SimplifiedValue;
4257 
4258   /// The runtime function kind of the callee of the associated call site.
4259   RuntimeFunction RFKind;
4260 };
4261 
4262 } // namespace
4263 
/// Register folding AAs for all call sites of the runtime function \p RF.
4265 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4266   auto &RFI = OMPInfoCache.RFIs[RF];
4267   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4268     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4269     if (!CI)
4270       return false;
4271     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4272         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4273         DepClassTy::NONE, /* ForceUpdate */ false,
4274         /* UpdateAfterInit */ false);
4275     return false;
4276   });
4277 }
4278 
4279 void OpenMPOpt::registerAAs(bool IsModulePass) {
  if (SCC.empty())
    return;
4283   if (IsModulePass) {
4284     // Ensure we create the AAKernelInfo AAs first and without triggering an
4285     // update. This will make sure we register all value simplification
4286     // callbacks before any other AA has the chance to create an AAValueSimplify
4287     // or similar.
4288     for (Function *Kernel : OMPInfoCache.Kernels)
4289       A.getOrCreateAAFor<AAKernelInfo>(
4290           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4291           DepClassTy::NONE, /* ForceUpdate */ false,
4292           /* UpdateAfterInit */ false);
4293 
4294     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4295     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4296     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4297     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4298     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4299   }
4300 
4301   // Create CallSite AA for all Getters.
4302   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4303     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4304 
4305     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4306 
4307     auto CreateAA = [&](Use &U, Function &Caller) {
4308       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4309       if (!CI)
4310         return false;
4311 
4312       auto &CB = cast<CallBase>(*CI);
4313 
4314       IRPosition CBPos = IRPosition::callsite_function(CB);
4315       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4316       return false;
4317     };
4318 
4319     GetterRFI.foreachUse(SCC, CreateAA);
4320   }
4321   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4322   auto CreateAA = [&](Use &U, Function &F) {
4323     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4324     return false;
4325   };
4326   if (!DisableOpenMPOptDeglobalization)
4327     GlobalizationRFI.foreachUse(SCC, CreateAA);
4328 
  // Create an ExecutionDomain AA for every function and a HeapToStack AA for
  // every function if we are compiling device code.
4331   if (!isOpenMPDevice(M))
4332     return;
4333 
4334   for (auto *F : SCC) {
4335     if (F->isDeclaration())
4336       continue;
4337 
4338     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4339     if (!DisableOpenMPOptDeglobalization)
4340       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4341 
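    // Ask for the simplified value of every load up front so the
    // corresponding simplification AAs are created before the first update.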
4342     for (auto &I : instructions(*F)) {
4343       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4344         bool UsedAssumedInformation = false;
4345         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4346                                UsedAssumedInformation);
4347       }
4348     }
4349   }
4350 }
4351 
4352 const char AAICVTracker::ID = 0;
4353 const char AAKernelInfo::ID = 0;
4354 const char AAExecutionDomain::ID = 0;
4355 const char AAHeapToShared::ID = 0;
4356 const char AAFoldRuntimeCall::ID = 0;
4357 
4358 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4359                                               Attributor &A) {
4360   AAICVTracker *AA = nullptr;
4361   switch (IRP.getPositionKind()) {
4362   case IRPosition::IRP_INVALID:
4363   case IRPosition::IRP_FLOAT:
4364   case IRPosition::IRP_ARGUMENT:
4365   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4366     llvm_unreachable("ICVTracker can only be created for function position!");
4367   case IRPosition::IRP_RETURNED:
4368     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4369     break;
4370   case IRPosition::IRP_CALL_SITE_RETURNED:
4371     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4372     break;
4373   case IRPosition::IRP_CALL_SITE:
4374     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4375     break;
4376   case IRPosition::IRP_FUNCTION:
4377     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4378     break;
4379   }
4380 
4381   return *AA;
4382 }
4383 
4384 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4385                                                         Attributor &A) {
4386   AAExecutionDomainFunction *AA = nullptr;
4387   switch (IRP.getPositionKind()) {
4388   case IRPosition::IRP_INVALID:
4389   case IRPosition::IRP_FLOAT:
4390   case IRPosition::IRP_ARGUMENT:
4391   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4392   case IRPosition::IRP_RETURNED:
4393   case IRPosition::IRP_CALL_SITE_RETURNED:
4394   case IRPosition::IRP_CALL_SITE:
4395     llvm_unreachable(
4396         "AAExecutionDomain can only be created for function position!");
4397   case IRPosition::IRP_FUNCTION:
4398     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4399     break;
4400   }
4401 
4402   return *AA;
4403 }
4404 
4405 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4406                                                   Attributor &A) {
4407   AAHeapToSharedFunction *AA = nullptr;
4408   switch (IRP.getPositionKind()) {
4409   case IRPosition::IRP_INVALID:
4410   case IRPosition::IRP_FLOAT:
4411   case IRPosition::IRP_ARGUMENT:
4412   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4413   case IRPosition::IRP_RETURNED:
4414   case IRPosition::IRP_CALL_SITE_RETURNED:
4415   case IRPosition::IRP_CALL_SITE:
4416     llvm_unreachable(
4417         "AAHeapToShared can only be created for function position!");
4418   case IRPosition::IRP_FUNCTION:
4419     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4420     break;
4421   }
4422 
4423   return *AA;
4424 }
4425 
4426 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4427                                               Attributor &A) {
4428   AAKernelInfo *AA = nullptr;
4429   switch (IRP.getPositionKind()) {
4430   case IRPosition::IRP_INVALID:
4431   case IRPosition::IRP_FLOAT:
4432   case IRPosition::IRP_ARGUMENT:
4433   case IRPosition::IRP_RETURNED:
4434   case IRPosition::IRP_CALL_SITE_RETURNED:
4435   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4436     llvm_unreachable("KernelInfo can only be created for function position!");
4437   case IRPosition::IRP_CALL_SITE:
4438     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4439     break;
4440   case IRPosition::IRP_FUNCTION:
4441     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4442     break;
4443   }
4444 
4445   return *AA;
4446 }
4447 
4448 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4449                                                         Attributor &A) {
4450   AAFoldRuntimeCall *AA = nullptr;
4451   switch (IRP.getPositionKind()) {
4452   case IRPosition::IRP_INVALID:
4453   case IRPosition::IRP_FLOAT:
4454   case IRPosition::IRP_ARGUMENT:
4455   case IRPosition::IRP_RETURNED:
4456   case IRPosition::IRP_FUNCTION:
4457   case IRPosition::IRP_CALL_SITE:
4458   case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable(
        "AAFoldRuntimeCall can only be created for call site returned "
        "position!");
4460   case IRPosition::IRP_CALL_SITE_RETURNED:
4461     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4462     break;
4463   }
4464 
4465   return *AA;
4466 }
4467 
4468 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4469   if (!containsOpenMP(M))
4470     return PreservedAnalyses::all();
4471   if (DisableOpenMPOptimizations)
4472     return PreservedAnalyses::all();
4473 
4474   FunctionAnalysisManager &FAM =
4475       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4476   KernelSet Kernels = getDeviceKernels(M);
4477 
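  // Kernels are always considered called; any other function counts as called
  // if it has at least one user that is not a block address.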
4478   auto IsCalled = [&](Function &F) {
4479     if (Kernels.contains(&F))
4480       return true;
4481     for (const User *U : F.users())
4482       if (!isa<BlockAddress>(U))
4483         return true;
4484     return false;
4485   };
4486 
4487   auto EmitRemark = [&](Function &F) {
4488     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4489     ORE.emit([&]() {
4490       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4491       return ORA << "Could not internalize function. "
4492                  << "Some optimizations may not be possible. [OMP140]";
4493     });
4494   };
4495 
4496   // Create internal copies of each function if this is a kernel Module. This
  // allows interprocedural passes to see every call edge.
4498   DenseMap<Function *, Function *> InternalizedMap;
4499   if (isOpenMPDevice(M)) {
4500     SmallPtrSet<Function *, 16> InternalizeFns;
4501     for (Function &F : M)
4502       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4503           !DisableInternalization) {
4504         if (Attributor::isInternalizable(F)) {
4505           InternalizeFns.insert(&F);
4506         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4507           EmitRemark(F);
4508         }
4509       }
4510 
4511     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4512   }
4513 
4514   // Look at every function in the Module unless it was internalized.
4515   SmallVector<Function *, 16> SCC;
4516   for (Function &F : M)
4517     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4518       SCC.push_back(&F);
4519 
4520   if (SCC.empty())
4521     return PreservedAnalyses::all();
4522 
4523   AnalysisGetter AG(FAM);
4524 
4525   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4526     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4527   };
4528 
4529   BumpPtrAllocator Allocator;
4530   CallGraphUpdater CGUpdater;
4531 
4532   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4533   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4534 
4535   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4536   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4537                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4538 
4539   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4540   bool Changed = OMPOpt.run(true);
4541 
4542   // Optionally inline device functions for potentially better performance.
4543   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4544     for (Function &F : M)
4545       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4546           !F.hasFnAttribute(Attribute::NoInline))
4547         F.addFnAttr(Attribute::AlwaysInline);
4548 
4549   if (PrintModuleAfterOptimizations)
4550     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4551 
4552   if (Changed)
4553     return PreservedAnalyses::none();
4554 
4555   return PreservedAnalyses::all();
4556 }
4557 
4558 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4559                                           CGSCCAnalysisManager &AM,
4560                                           LazyCallGraph &CG,
4561                                           CGSCCUpdateResult &UR) {
4562   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4563     return PreservedAnalyses::all();
4564   if (DisableOpenMPOptimizations)
4565     return PreservedAnalyses::all();
4566 
4567   SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCCs.
4569   for (LazyCallGraph::Node &N : C) {
4570     Function *Fn = &N.getFunction();
4571     SCC.push_back(Fn);
4572   }
4573 
4574   if (SCC.empty())
4575     return PreservedAnalyses::all();
4576 
4577   Module &M = *C.begin()->getFunction().getParent();
4578 
4579   KernelSet Kernels = getDeviceKernels(M);
4580 
4581   FunctionAnalysisManager &FAM =
4582       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4583 
4584   AnalysisGetter AG(FAM);
4585 
4586   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4587     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4588   };
4589 
4590   BumpPtrAllocator Allocator;
4591   CallGraphUpdater CGUpdater;
4592   CGUpdater.initialize(CG, C, AM, UR);
4593 
4594   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4595   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4596                                 /*CGSCC*/ Functions, Kernels);
4597 
4598   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4599   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4600                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4601 
4602   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4603   bool Changed = OMPOpt.run(false);
4604 
4605   if (PrintModuleAfterOptimizations)
4606     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4607 
4608   if (Changed)
4609     return PreservedAnalyses::none();
4610 
4611   return PreservedAnalyses::all();
4612 }
4613 
4614 namespace {
4615 
4616 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4617   CallGraphUpdater CGUpdater;
4618   static char ID;
4619 
4620   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4621     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4622   }
4623 
4624   void getAnalysisUsage(AnalysisUsage &AU) const override {
4625     CallGraphSCCPass::getAnalysisUsage(AU);
4626   }
4627 
4628   bool runOnSCC(CallGraphSCC &CGSCC) override {
4629     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4630       return false;
4631     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4632       return false;
4633 
4634     SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCCs.
4636     for (CallGraphNode *CGN : CGSCC) {
4637       Function *Fn = CGN->getFunction();
4638       if (!Fn || Fn->isDeclaration())
4639         continue;
4640       SCC.push_back(Fn);
4641     }
4642 
4643     if (SCC.empty())
4644       return false;
4645 
4646     Module &M = CGSCC.getCallGraph().getModule();
4647     KernelSet Kernels = getDeviceKernels(M);
4648 
4649     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4650     CGUpdater.initialize(CG, CGSCC);
4651 
4652     // Maintain a map of functions to avoid rebuilding the ORE
4653     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4654     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4655       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4656       if (!ORE)
4657         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4658       return *ORE;
4659     };
4660 
4661     AnalysisGetter AG;
4662     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4663     BumpPtrAllocator Allocator;
4664     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4665                                   Allocator,
4666                                   /*CGSCC*/ Functions, Kernels);
4667 
4668     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4669     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4670                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4671 
4672     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4673     bool Result = OMPOpt.run(false);
4674 
4675     if (PrintModuleAfterOptimizations)
4676       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4677 
4678     return Result;
4679   }
4680 
4681   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4682 };
4683 
4684 } // end anonymous namespace
4685 
4686 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4687   // TODO: Create a more cross-platform way of determining device kernels.
4688   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4689   KernelSet Kernels;
4690 
4691   if (!MD)
4692     return Kernels;
4693 
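  // A kernel is encoded as an "nvvm.annotations" entry whose first operand is
  // the kernel function and whose second operand is the string "kernel".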
4694   for (auto *Op : MD->operands()) {
4695     if (Op->getNumOperands() < 2)
4696       continue;
4697     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4698     if (!KindID || KindID->getString() != "kernel")
4699       continue;
4700 
4701     Function *KernelFn =
4702         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4703     if (!KernelFn)
4704       continue;
4705 
4706     ++NumOpenMPTargetRegionKernels;
4707 
4708     Kernels.insert(KernelFn);
4709   }
4710 
4711   return Kernels;
4712 }
4713 
4714 bool llvm::omp::containsOpenMP(Module &M) {
4715   Metadata *MD = M.getModuleFlag("openmp");
4716   if (!MD)
4717     return false;
4718 
4719   return true;
4720 }
4721 
4722 bool llvm::omp::isOpenMPDevice(Module &M) {
4723   Metadata *MD = M.getModuleFlag("openmp-device");
4724   if (!MD)
4725     return false;
4726 
4727   return true;
4728 }
4729 
4730 char OpenMPOptCGSCCLegacyPass::ID = 0;
4731 
4732 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4733                       "OpenMP specific optimizations", false, false)
4734 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4735 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4736                     "OpenMP specific optimizations", false, false)
4737 
4738 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4739   return new OpenMPOptCGSCCLegacyPass();
4740 }
4741