1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
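// As an illustrative sketch (not the exact IR the transformations produce),
// the runtime call deduplication turns
//
//   %a = call i32 @omp_get_thread_num()
//   ...
//   %b = call i32 @omp_get_thread_num()
//
// into a single call whose result is reused for both %a and %b, provided no
// instruction in between can change the thread number.
//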
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/IR/IntrinsicsAMDGPU.h"
37 #include "llvm/IR/IntrinsicsNVPTX.h"
38 #include "llvm/InitializePasses.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Transforms/IPO.h"
41 #include "llvm/Transforms/IPO/Attributor.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
44 #include "llvm/Transforms/Utils/CodeExtractor.h"
45 
46 using namespace llvm;
47 using namespace omp;
48 
49 #define DEBUG_TYPE "openmp-opt"
50 
51 static cl::opt<bool> DisableOpenMPOptimizations(
52     "openmp-opt-disable", cl::ZeroOrMore,
53     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
54     cl::init(false));
55 
56 static cl::opt<bool> EnableParallelRegionMerging(
57     "openmp-opt-enable-merging", cl::ZeroOrMore,
58     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
59     cl::init(false));
60 
61 static cl::opt<bool>
62     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
63                            cl::desc("Disable function internalization."),
64                            cl::Hidden, cl::init(false));
65 
66 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
67                                     cl::Hidden);
68 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
69                                         cl::init(false), cl::Hidden);
70 
71 static cl::opt<bool> HideMemoryTransferLatency(
72     "openmp-hide-memory-transfer-latency",
73     cl::desc("[WIP] Tries to hide the latency of host to device memory"
74              " transfers"),
75     cl::Hidden, cl::init(false));
76 
77 static cl::opt<bool> DisableOpenMPOptDeglobalization(
78     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
79     cl::desc("Disable OpenMP optimizations involving deglobalization."),
80     cl::Hidden, cl::init(false));
81 
82 static cl::opt<bool> DisableOpenMPOptSPMDization(
83     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
84     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
85     cl::Hidden, cl::init(false));
86 
87 static cl::opt<bool> DisableOpenMPOptFolding(
88     "openmp-opt-disable-folding", cl::ZeroOrMore,
89     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
90     cl::init(false));
91 
92 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
93     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
94     cl::desc("Disable OpenMP optimizations that replace the state machine."),
95     cl::Hidden, cl::init(false));
96 
97 static cl::opt<bool> PrintModuleAfterOptimizations(
98     "openmp-opt-print-module", cl::ZeroOrMore,
99     cl::desc("Print the current module after OpenMP optimizations."),
100     cl::Hidden, cl::init(false));
101 
102 static cl::opt<bool> AlwaysInlineDeviceFunctions(
103     "openmp-opt-inline-device", cl::ZeroOrMore,
104     cl::desc("Inline all applicable functions on the device."), cl::Hidden,
105     cl::init(false));
106 
107 static cl::opt<bool>
108     EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
109                          cl::desc("Enables more verbose remarks."), cl::Hidden,
110                          cl::init(false));
111 
112 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
113           "Number of OpenMP runtime calls deduplicated");
114 STATISTIC(NumOpenMPParallelRegionsDeleted,
115           "Number of OpenMP parallel regions deleted");
116 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
117           "Number of OpenMP runtime functions identified");
118 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
119           "Number of OpenMP runtime function uses identified");
120 STATISTIC(NumOpenMPTargetRegionKernels,
121           "Number of OpenMP target region entry points (=kernels) identified");
122 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
123           "Number of OpenMP target region entry points (=kernels) executed in "
124           "SPMD-mode instead of generic-mode");
125 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
126           "Number of OpenMP target region entry points (=kernels) executed in "
127           "generic-mode without a state machine");
128 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
129           "Number of OpenMP target region entry points (=kernels) executed in "
130           "generic-mode with customized state machines with fallback");
131 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
132           "Number of OpenMP target region entry points (=kernels) executed in "
133           "generic-mode with customized state machines without fallback");
134 STATISTIC(
135     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
136     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
137 STATISTIC(NumOpenMPParallelRegionsMerged,
138           "Number of OpenMP parallel regions merged");
139 STATISTIC(NumBytesMovedToSharedMemory,
140           "Amount of memory pushed to shared memory");
141 
142 #if !defined(NDEBUG)
143 static constexpr auto TAG = "[" DEBUG_TYPE "]";
144 #endif
145 
146 namespace {
147 
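/// Address space numbering as used by the GPU targets. For example, Shared (3)
/// is the address space the deglobalization below uses when it moves
/// globalized device memory into shared memory.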
148 enum class AddressSpace : unsigned {
149   Generic = 0,
150   Global = 1,
151   Shared = 3,
152   Constant = 4,
153   Local = 5,
154 };
155 
156 struct AAHeapToShared;
157 
158 struct AAICVTracker;
159 
160 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
161 /// Attributor runs.
162 struct OMPInformationCache : public InformationCache {
163   OMPInformationCache(Module &M, AnalysisGetter &AG,
164                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
165                       SmallPtrSetImpl<Kernel> &Kernels)
166       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
167         Kernels(Kernels) {
168 
169     OMPBuilder.initialize();
170     initializeRuntimeFunctions();
171     initializeInternalControlVars();
172   }
173 
174   /// Generic information that describes an internal control variable.
175   struct InternalControlVarInfo {
176     /// The kind, as described by InternalControlVar enum.
177     InternalControlVar Kind;
178 
179     /// The name of the ICV.
180     StringRef Name;
181 
182     /// Environment variable associated with this ICV.
183     StringRef EnvVarName;
184 
185     /// Initial value kind.
186     ICVInitValue InitKind;
187 
188     /// Initial value.
189     ConstantInt *InitValue;
190 
191     /// Setter RTL function associated with this ICV.
192     RuntimeFunction Setter;
193 
194     /// Getter RTL function associated with this ICV.
195     RuntimeFunction Getter;
196 
197     /// RTL function corresponding to the override clause of this ICV.
198     RuntimeFunction Clause;
199   };
200 
201   /// Generic information that describes a runtime function.
202   struct RuntimeFunctionInfo {
203 
204     /// The kind, as described by the RuntimeFunction enum.
205     RuntimeFunction Kind;
206 
207     /// The name of the function.
208     StringRef Name;
209 
210     /// Flag to indicate a variadic function.
211     bool IsVarArg;
212 
213     /// The return type of the function.
214     Type *ReturnType;
215 
216     /// The argument types of the function.
217     SmallVector<Type *, 8> ArgumentTypes;
218 
219     /// The declaration if available.
220     Function *Declaration = nullptr;
221 
222     /// Uses of this runtime function per function containing the use.
223     using UseVector = SmallVector<Use *, 16>;
224 
225     /// Clear UsesMap for runtime function.
226     void clearUsesMap() { UsesMap.clear(); }
227 
228     /// Boolean conversion that is true if the runtime function was found.
229     operator bool() const { return Declaration; }
230 
231     /// Return the vector of uses in function \p F.
232     UseVector &getOrCreateUseVector(Function *F) {
233       std::shared_ptr<UseVector> &UV = UsesMap[F];
234       if (!UV)
235         UV = std::make_shared<UseVector>();
236       return *UV;
237     }
238 
239     /// Return the vector of uses in function \p F or `nullptr` if there are
240     /// none.
241     const UseVector *getUseVector(Function &F) const {
242       auto I = UsesMap.find(&F);
243       if (I != UsesMap.end())
244         return I->second.get();
245       return nullptr;
246     }
247 
248     /// Return how many functions contain uses of this runtime function.
249     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
250 
251     /// Return the number of arguments (or the minimal number for variadic
252     /// functions).
253     size_t getNumArgs() const { return ArgumentTypes.size(); }
254 
255     /// Run the callback \p CB on each use and forget the use if the result is
256     /// true. The callback will be fed the function in which the use was
257     /// encountered as the second argument.
258     void foreachUse(SmallVectorImpl<Function *> &SCC,
259                     function_ref<bool(Use &, Function &)> CB) {
260       for (Function *F : SCC)
261         foreachUse(CB, F);
262     }
263 
264     /// Run the callback \p CB on each use within the function \p F and forget
265     /// the use if the result is true.
266     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
267       SmallVector<unsigned, 8> ToBeDeleted;
269 
270       unsigned Idx = 0;
271       UseVector &UV = getOrCreateUseVector(F);
272 
273       for (Use *U : UV) {
274         if (CB(*U, *F))
275           ToBeDeleted.push_back(Idx);
276         ++Idx;
277       }
278 
279       // Remove the to-be-deleted indices, largest first, so that the swap with
280       // the last element never invalidates a smaller index still to be removed.
281       while (!ToBeDeleted.empty()) {
282         unsigned Idx = ToBeDeleted.pop_back_val();
283         UV[Idx] = UV.back();
284         UV.pop_back();
285       }
286     }
287 
288   private:
289     /// Map from functions to all uses of this runtime function contained in
290     /// them.
291     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
292 
293   public:
294     /// Iterators for the uses of this runtime function.
295     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
296     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
297   };
298 
299   /// An OpenMP-IR-Builder instance.
300   OpenMPIRBuilder OMPBuilder;
301 
302   /// Map from runtime function kind to the runtime function description.
303   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
304                   RuntimeFunction::OMPRTL___last>
305       RFIs;
306 
307   /// Map from function declarations/definitions to their runtime enum type.
308   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
309 
310   /// Map from ICV kind to the ICV description.
311   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
312                   InternalControlVar::ICV___last>
313       ICVs;
314 
315   /// Helper to initialize all internal control variable information for those
316   /// defined in OMPKinds.def.
317   void initializeInternalControlVars() {
318 #define ICV_RT_SET(_Name, RTL)                                                 \
319   {                                                                            \
320     auto &ICV = ICVs[_Name];                                                   \
321     ICV.Setter = RTL;                                                          \
322   }
323 #define ICV_RT_GET(Name, RTL)                                                  \
324   {                                                                            \
325     auto &ICV = ICVs[Name];                                                    \
326     ICV.Getter = RTL;                                                          \
327   }
328 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
329   {                                                                            \
330     auto &ICV = ICVs[Enum];                                                    \
331     ICV.Name = _Name;                                                          \
332     ICV.Kind = Enum;                                                           \
333     ICV.InitKind = Init;                                                       \
334     ICV.EnvVarName = _EnvVarName;                                              \
335     switch (ICV.InitKind) {                                                    \
336     case ICV_IMPLEMENTATION_DEFINED:                                           \
337       ICV.InitValue = nullptr;                                                 \
338       break;                                                                   \
339     case ICV_ZERO:                                                             \
340       ICV.InitValue = ConstantInt::get(                                        \
341           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
342       break;                                                                   \
343     case ICV_FALSE:                                                            \
344       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
345       break;                                                                   \
346     case ICV_LAST:                                                             \
347       break;                                                                   \
348     }                                                                          \
349   }
350 #include "llvm/Frontend/OpenMP/OMPKinds.def"
351   }
352 
353   /// Returns true if the function declaration \p F matches the runtime
354   /// function types, that is, return type \p RTFRetType, and argument types
355   /// \p RTFArgTypes.
356   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
357                                   SmallVector<Type *, 8> &RTFArgTypes) {
358     // TODO: We should output information to the user (under debug output
359     //       and via remarks).
360 
361     if (!F)
362       return false;
363     if (F->getReturnType() != RTFRetType)
364       return false;
365     if (F->arg_size() != RTFArgTypes.size())
366       return false;
367 
368     auto RTFTyIt = RTFArgTypes.begin();
369     for (Argument &Arg : F->args()) {
370       if (Arg.getType() != *RTFTyIt)
371         return false;
372 
373       ++RTFTyIt;
374     }
375 
376     return true;
377   }
378 
379   // Helper to collect all uses of the declaration in the UsesMap.
380   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
381     unsigned NumUses = 0;
382     if (!RFI.Declaration)
383       return NumUses;
384     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
385 
386     if (CollectStats) {
387       NumOpenMPRuntimeFunctionsIdentified += 1;
388       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
389     }
390 
391     // TODO: We directly convert uses into proper calls and unknown uses.
392     for (Use &U : RFI.Declaration->uses()) {
393       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
394         if (ModuleSlice.count(UserI->getFunction())) {
395           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
396           ++NumUses;
397         }
398       } else {
399         RFI.getOrCreateUseVector(nullptr).push_back(&U);
400         ++NumUses;
401       }
402     }
403     return NumUses;
404   }
405 
406   // Helper function to recollect uses of a runtime function.
407   void recollectUsesForFunction(RuntimeFunction RTF) {
408     auto &RFI = RFIs[RTF];
409     RFI.clearUsesMap();
410     collectUses(RFI, /*CollectStats*/ false);
411   }
412 
413   // Helper function to recollect uses of all runtime functions.
414   void recollectUses() {
415     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
416       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
417   }
418 
419   /// Helper to initialize all runtime function information for those defined
420   /// in OMPKinds.def.
421   void initializeRuntimeFunctions() {
422     Module &M = *((*ModuleSlice.begin())->getParent());
423 
424     // Helper macros for handling __VA_ARGS__ in OMP_RTL
425 #define OMP_TYPE(VarName, ...)                                                 \
426   Type *VarName = OMPBuilder.VarName;                                          \
427   (void)VarName;
428 
429 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
430   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
431   (void)VarName##Ty;                                                           \
432   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
433   (void)VarName##PtrTy;
434 
435 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
436   FunctionType *VarName = OMPBuilder.VarName;                                  \
437   (void)VarName;                                                               \
438   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
439   (void)VarName##Ptr;
440 
441 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
442   StructType *VarName = OMPBuilder.VarName;                                    \
443   (void)VarName;                                                               \
444   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
445   (void)VarName##Ptr;
446 
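    // For every runtime function listed in OMPKinds.def, look up a declaration
    // with a matching prototype in the module; if one exists, record it in
    // RFIs / RuntimeFunctionIDMap and collect all of its uses.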
447 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
448   {                                                                            \
449     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
450     Function *F = M.getFunction(_Name);                                        \
451     RTLFunctions.insert(F);                                                    \
452     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
453       RuntimeFunctionIDMap[F] = _Enum;                                         \
454       F->removeFnAttr(Attribute::NoInline);                                    \
455       auto &RFI = RFIs[_Enum];                                                 \
456       RFI.Kind = _Enum;                                                        \
457       RFI.Name = _Name;                                                        \
458       RFI.IsVarArg = _IsVarArg;                                                \
459       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
460       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
461       RFI.Declaration = F;                                                     \
462       unsigned NumUses = collectUses(RFI);                                     \
463       (void)NumUses;                                                           \
464       LLVM_DEBUG({                                                             \
465         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
466                << " found\n";                                                  \
467         if (RFI.Declaration)                                                   \
468           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
469                  << RFI.getNumFunctionsWithUses()                              \
470                  << " different functions.\n";                                 \
471       });                                                                      \
472     }                                                                          \
473   }
474 #include "llvm/Frontend/OpenMP/OMPKinds.def"
475 
476     // TODO: We should attach the attributes defined in OMPKinds.def.
477   }
478 
479   /// Collection of known kernels (\see Kernel) in the module.
480   SmallPtrSetImpl<Kernel> &Kernels;
481 
482   /// Collection of known OpenMP runtime functions.
483   DenseSet<const Function *> RTLFunctions;
484 };
485 
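/// A BooleanState augmented with a SetVector of elements. If
/// \p InsertInvalidates is true, inserting a new element also moves the boolean
/// state to the pessimistic fixpoint (see insert()).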
486 template <typename Ty, bool InsertInvalidates = true>
487 struct BooleanStateWithSetVector : public BooleanState {
488   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
489   bool insert(const Ty &Elem) {
490     if (InsertInvalidates)
491       BooleanState::indicatePessimisticFixpoint();
492     return Set.insert(Elem);
493   }
494 
495   const Ty &operator[](int Idx) const { return Set[Idx]; }
496   bool operator==(const BooleanStateWithSetVector &RHS) const {
497     return BooleanState::operator==(RHS) && Set == RHS.Set;
498   }
499   bool operator!=(const BooleanStateWithSetVector &RHS) const {
500     return !(*this == RHS);
501   }
502 
503   bool empty() const { return Set.empty(); }
504   size_t size() const { return Set.size(); }
505 
506   /// "Clamp" this state with \p RHS.
507   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
508     BooleanState::operator^=(RHS);
509     Set.insert(RHS.Set.begin(), RHS.Set.end());
510     return *this;
511   }
512 
513 private:
514   /// A set to keep track of elements.
515   SetVector<Ty> Set;
516 
517 public:
518   typename decltype(Set)::iterator begin() { return Set.begin(); }
519   typename decltype(Set)::iterator end() { return Set.end(); }
520   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
521   typename decltype(Set)::const_iterator end() const { return Set.end(); }
522 };
523 
524 template <typename Ty, bool InsertInvalidates = true>
525 using BooleanStateWithPtrSetVector =
526     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
527 
528 struct KernelInfoState : AbstractState {
529   /// Flag to track if we reached a fixpoint.
530   bool IsAtFixpoint = false;
531 
532   /// The parallel regions (identified by the outlined parallel functions) that
533   /// can be reached from the associated function.
534   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
535       ReachedKnownParallelRegions;
536 
537   /// State to track what parallel region we might reach.
538   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
539 
540   /// State to track if we are in SPMD-mode, assumed or known, and why we
541   /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
542   /// also be false.
543   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
544 
545   /// The __kmpc_target_init call in this kernel, if any. If we find more than
546   /// one we abort as the kernel is malformed.
547   CallBase *KernelInitCB = nullptr;
548 
549   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
550   /// one we abort as the kernel is malformed.
551   CallBase *KernelDeinitCB = nullptr;
552 
553   /// Flag to indicate if the associated function is a kernel entry.
554   bool IsKernelEntry = false;
555 
556   /// State to track what kernel entries can reach the associated function.
557   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
558 
559   /// State to indicate if we can track parallel level of the associated
560   /// function. We will give up tracking if we encounter unknown caller or the
561   /// caller is __kmpc_parallel_51.
562   BooleanStateWithSetVector<uint8_t> ParallelLevels;
563 
564   /// Abstract State interface
565   ///{
566 
567   KernelInfoState() {}
568   KernelInfoState(bool BestState) {
569     if (!BestState)
570       indicatePessimisticFixpoint();
571   }
572 
573   /// See AbstractState::isValidState(...)
574   bool isValidState() const override { return true; }
575 
576   /// See AbstractState::isAtFixpoint(...)
577   bool isAtFixpoint() const override { return IsAtFixpoint; }
578 
579   /// See AbstractState::indicatePessimisticFixpoint(...)
580   ChangeStatus indicatePessimisticFixpoint() override {
581     IsAtFixpoint = true;
582     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
583     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
584     return ChangeStatus::CHANGED;
585   }
586 
587   /// See AbstractState::indicateOptimisticFixpoint(...)
588   ChangeStatus indicateOptimisticFixpoint() override {
589     IsAtFixpoint = true;
590     return ChangeStatus::UNCHANGED;
591   }
592 
593   /// Return the assumed state
594   KernelInfoState &getAssumed() { return *this; }
595   const KernelInfoState &getAssumed() const { return *this; }
596 
597   bool operator==(const KernelInfoState &RHS) const {
598     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
599       return false;
600     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
601       return false;
602     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
603       return false;
604     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
605       return false;
606     return true;
607   }
608 
609   /// Returns true if this kernel contains any OpenMP parallel regions.
610   bool mayContainParallelRegion() {
611     return !ReachedKnownParallelRegions.empty() ||
612            !ReachedUnknownParallelRegions.empty();
613   }
614 
615   /// Return empty set as the best state of potential values.
616   static KernelInfoState getBestState() { return KernelInfoState(true); }
617 
618   static KernelInfoState getBestState(KernelInfoState &KIS) {
619     return getBestState();
620   }
621 
622   /// Return full set as the worst state of potential values.
623   static KernelInfoState getWorstState() { return KernelInfoState(false); }
624 
625   /// "Clamp" this state with \p KIS.
626   KernelInfoState operator^=(const KernelInfoState &KIS) {
627     // Do not merge two different _init and _deinit call sites.
628     if (KIS.KernelInitCB) {
629       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
630         indicatePessimisticFixpoint();
631       KernelInitCB = KIS.KernelInitCB;
632     }
633     if (KIS.KernelDeinitCB) {
634       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
635         indicatePessimisticFixpoint();
636       KernelDeinitCB = KIS.KernelDeinitCB;
637     }
638     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
639     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
640     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
641     return *this;
642   }
643 
644   KernelInfoState operator&=(const KernelInfoState &KIS) {
645     return (*this ^= KIS);
646   }
647 
648   ///}
649 };
650 
651 /// Used to map the values physically (in the IR) stored in an offload
652 /// array, to a vector in memory.
653 struct OffloadArray {
654   /// Physical array (in the IR).
655   AllocaInst *Array = nullptr;
656   /// Mapped values.
657   SmallVector<Value *, 8> StoredValues;
658   /// Last stores made in the offload array.
659   SmallVector<StoreInst *, 8> LastAccesses;
660 
661   OffloadArray() = default;
662 
663   /// Initializes the OffloadArray with the values stored in \p Array before
664   /// instruction \p Before is reached. Returns false if the initialization
665   /// fails.
666   /// This MUST be used immediately after the construction of the object.
667   bool initialize(AllocaInst &Array, Instruction &Before) {
668     if (!Array.getAllocatedType()->isArrayTy())
669       return false;
670 
671     if (!getValues(Array, Before))
672       return false;
673 
674     this->Array = &Array;
675     return true;
676   }
677 
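  // Argument positions of the device id and the offload arrays in calls to the
  // __tgt_target_data_*_mapper runtime functions.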
678   static const unsigned DeviceIDArgNum = 1;
679   static const unsigned BasePtrsArgNum = 3;
680   static const unsigned PtrsArgNum = 4;
681   static const unsigned SizesArgNum = 5;
682 
683 private:
684   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
685   /// \p Array, leaving StoredValues with the values stored before the
686   /// instruction \p Before is reached.
687   bool getValues(AllocaInst &Array, Instruction &Before) {
688     // Initialize container.
689     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
690     StoredValues.assign(NumValues, nullptr);
691     LastAccesses.assign(NumValues, nullptr);
692 
693     // TODO: This assumes the instruction \p Before is in the same
694     //  BasicBlock as Array. Make it general, for any control flow graph.
695     BasicBlock *BB = Array.getParent();
696     if (BB != Before.getParent())
697       return false;
698 
699     const DataLayout &DL = Array.getModule()->getDataLayout();
700     const unsigned int PointerSize = DL.getPointerSize();
701 
702     for (Instruction &I : *BB) {
703       if (&I == &Before)
704         break;
705 
706       if (!isa<StoreInst>(&I))
707         continue;
708 
709       auto *S = cast<StoreInst>(&I);
710       int64_t Offset = -1;
711       auto *Dst =
712           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
713       if (Dst == &Array) {
714         int64_t Idx = Offset / PointerSize;
715         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
716         LastAccesses[Idx] = S;
717       }
718     }
719 
720     return isFilled();
721   }
722 
723   /// Returns true if all values in StoredValues and
724   /// LastAccesses are not nullptrs.
725   bool isFilled() {
726     const unsigned NumValues = StoredValues.size();
727     for (unsigned I = 0; I < NumValues; ++I) {
728       if (!StoredValues[I] || !LastAccesses[I])
729         return false;
730     }
731 
732     return true;
733   }
734 };
735 
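/// Driver for the OpenMP optimizations in this file. It runs on the functions
/// of a single SCC (or, in the module pass, the whole module slice), performs
/// the hand-rolled transformations (deduplication, region merging, deletion,
/// ...) and kicks off the Attributor-based ones.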
736 struct OpenMPOpt {
737 
738   using OptimizationRemarkGetter =
739       function_ref<OptimizationRemarkEmitter &(Function *)>;
740 
741   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
742             OptimizationRemarkGetter OREGetter,
743             OMPInformationCache &OMPInfoCache, Attributor &A)
744       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
745         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
746 
747   /// Check if any remarks are enabled for openmp-opt
748   bool remarksEnabled() {
749     auto &Ctx = M.getContext();
750     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
751   }
752 
753   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
754   bool run(bool IsModulePass) {
755     if (SCC.empty())
756       return false;
757 
758     bool Changed = false;
759 
760     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
761                       << " functions in a slice with "
762                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
763 
764     if (IsModulePass) {
765       Changed |= runAttributor(IsModulePass);
766 
767       // Recollect uses, in case Attributor deleted any.
768       OMPInfoCache.recollectUses();
769 
770       // TODO: This should be folded into buildCustomStateMachine.
771       Changed |= rewriteDeviceCodeStateMachine();
772 
773       if (remarksEnabled())
774         analysisGlobalization();
775     } else {
776       if (PrintICVValues)
777         printICVs();
778       if (PrintOpenMPKernels)
779         printKernels();
780 
781       Changed |= runAttributor(IsModulePass);
782 
783       // Recollect uses, in case Attributor deleted any.
784       OMPInfoCache.recollectUses();
785 
786       Changed |= deleteParallelRegions();
787 
788       if (HideMemoryTransferLatency)
789         Changed |= hideMemTransfersLatency();
790       Changed |= deduplicateRuntimeCalls();
791       if (EnableParallelRegionMerging) {
792         if (mergeParallelRegions()) {
793           deduplicateRuntimeCalls();
794           Changed = true;
795         }
796       }
797     }
798 
799     return Changed;
800   }
801 
802   /// Print initial ICV values for testing.
803   /// FIXME: This should be done from the Attributor once it is added.
804   void printICVs() const {
805     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
806                                  ICV_proc_bind};
807 
808     for (Function *F : OMPInfoCache.ModuleSlice) {
809       for (auto ICV : ICVs) {
810         auto ICVInfo = OMPInfoCache.ICVs[ICV];
811         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
812           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
813                      << " Value: "
814                      << (ICVInfo.InitValue
815                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
816                              : "IMPLEMENTATION_DEFINED");
817         };
818 
819         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
820       }
821     }
822   }
823 
824   /// Print OpenMP GPU kernels for testing.
825   void printKernels() const {
826     for (Function *F : SCC) {
827       if (!OMPInfoCache.Kernels.count(F))
828         continue;
829 
830       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
831         return ORA << "OpenMP GPU kernel "
832                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
833       };
834 
835       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
836     }
837   }
838 
839   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
840   /// given it has to be the callee or a nullptr is returned.
841   static CallInst *getCallIfRegularCall(
842       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
843     CallInst *CI = dyn_cast<CallInst>(U.getUser());
844     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
845         (!RFI ||
846          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
847       return CI;
848     return nullptr;
849   }
850 
851   /// Return the call if \p V is a regular call. If \p RFI is given it has to be
852   /// the callee or a nullptr is returned.
853   static CallInst *getCallIfRegularCall(
854       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
855     CallInst *CI = dyn_cast<CallInst>(&V);
856     if (CI && !CI->hasOperandBundles() &&
857         (!RFI ||
858          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
859       return CI;
860     return nullptr;
861   }
862 
863 private:
864   /// Merge parallel regions when it is safe.
865   bool mergeParallelRegions() {
866     const unsigned CallbackCalleeOperand = 2;
867     const unsigned CallbackFirstArgOperand = 3;
868     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
869 
870     // Check if there are any __kmpc_fork_call calls to merge.
871     OMPInformationCache::RuntimeFunctionInfo &RFI =
872         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
873 
874     if (!RFI.Declaration)
875       return false;
876 
877     // Unmergable calls that prevent merging a parallel region.
878     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
879         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
880         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
881     };
882 
883     bool Changed = false;
884     LoopInfo *LI = nullptr;
885     DominatorTree *DT = nullptr;
886 
887     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
888 
889     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
890     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
891                          BasicBlock &ContinuationIP) {
892       BasicBlock *CGStartBB = CodeGenIP.getBlock();
893       BasicBlock *CGEndBB =
894           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
895       assert(StartBB != nullptr && "StartBB should not be null");
896       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
897       assert(EndBB != nullptr && "EndBB should not be null");
898       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
899     };
900 
901     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
902                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
903       ReplacementValue = &Inner;
904       return CodeGenIP;
905     };
906 
907     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
908 
909     /// Create a sequential execution region within a merged parallel region,
910     /// encapsulated in a master construct with a barrier for synchronization.
911     auto CreateSequentialRegion = [&](Function *OuterFn,
912                                       BasicBlock *OuterPredBB,
913                                       Instruction *SeqStartI,
914                                       Instruction *SeqEndI) {
915       // Isolate the instructions of the sequential region to a separate
916       // block.
917       BasicBlock *ParentBB = SeqStartI->getParent();
918       BasicBlock *SeqEndBB =
919           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
920       BasicBlock *SeqAfterBB =
921           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
922       BasicBlock *SeqStartBB =
923           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
924 
925       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
926              "Expected a different CFG");
927       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
928       ParentBB->getTerminator()->eraseFromParent();
929 
930       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
931                            BasicBlock &ContinuationIP) {
932         BasicBlock *CGStartBB = CodeGenIP.getBlock();
933         BasicBlock *CGEndBB =
934             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
935         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
936         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
937         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
938         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
939       };
940       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
941 
942       // Find outputs from the sequential region to outside users and
943       // broadcast their values to them.
944       for (Instruction &I : *SeqStartBB) {
945         SmallPtrSet<Instruction *, 4> OutsideUsers;
946         for (User *Usr : I.users()) {
947           Instruction &UsrI = *cast<Instruction>(Usr);
948           // Ignore outputs to lifetime intrinsics; the code extraction for the
949           // merged parallel region will fix them.
950           if (UsrI.isLifetimeStartOrEnd())
951             continue;
952 
953           if (UsrI.getParent() != SeqStartBB)
954             OutsideUsers.insert(&UsrI);
955         }
956 
957         if (OutsideUsers.empty())
958           continue;
959 
960         // Emit an alloca in the outer region to store the broadcasted
961         // value.
962         const DataLayout &DL = M.getDataLayout();
963         AllocaInst *AllocaI = new AllocaInst(
964             I.getType(), DL.getAllocaAddrSpace(), nullptr,
965             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
966 
967         // Emit a store instruction in the sequential BB to update the
968         // value.
969         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
970 
971         // Emit a load instruction and replace the use of the output value
972         // with it.
973         for (Instruction *UsrI : OutsideUsers) {
974           LoadInst *LoadI = new LoadInst(
975               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
976           UsrI->replaceUsesOfWith(&I, LoadI);
977         }
978       }
979 
980       OpenMPIRBuilder::LocationDescription Loc(
981           InsertPointTy(ParentBB, ParentBB->end()), DL);
982       InsertPointTy SeqAfterIP =
983           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
984 
985       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
986 
987       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
988 
989       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
990                         << "\n");
991     };
992 
993     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
994     // contained in BB and only separated by instructions that can be
995     // redundantly executed in parallel. The block BB is split before the first
996     // call (in MergableCIs) and after the last so the entire region we merge
997     // into a single parallel region is contained in a single basic block
998     // without any other instructions. We use the OpenMPIRBuilder to outline
999     // that block and call the resulting function via __kmpc_fork_call.
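    //
    // Roughly (a sketch, not the exact IR):
    //
    //   __kmpc_fork_call(..., @outlined_a, <args_a>)
    //   <sequential code>
    //   __kmpc_fork_call(..., @outlined_b, <args_b>)
    //
    // becomes a single __kmpc_fork_call of a new, merged outlined function that
    // calls @outlined_a and @outlined_b directly (separated by an explicit
    // barrier), with the sequential code wrapped in a master region plus
    // barrier by CreateSequentialRegion above.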
1000     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
1001       // TODO: Change the interface to allow single CIs expanded, e.g, to
1002       // include an outer loop.
1003       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1004 
1005       auto Remark = [&](OptimizationRemark OR) {
1006         OR << "Parallel region merged with parallel region"
1007            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1008         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1009           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1010           if (CI != MergableCIs.back())
1011             OR << ", ";
1012         }
1013         return OR << ".";
1014       };
1015 
1016       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1017 
1018       Function *OriginalFn = BB->getParent();
1019       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1020                         << " parallel regions in " << OriginalFn->getName()
1021                         << "\n");
1022 
1023       // Isolate the calls to merge in a separate block.
1024       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1025       BasicBlock *AfterBB =
1026           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1027       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1028                            "omp.par.merged");
1029 
1030       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1031       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1032       BB->getTerminator()->eraseFromParent();
1033 
1034       // Create sequential regions for sequential instructions that are
1035       // in-between mergable parallel regions.
1036       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1037            It != End; ++It) {
1038         Instruction *ForkCI = *It;
1039         Instruction *NextForkCI = *(It + 1);
1040 
1041         // Continue if there are no in-between instructions.
1042         if (ForkCI->getNextNode() == NextForkCI)
1043           continue;
1044 
1045         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1046                                NextForkCI->getPrevNode());
1047       }
1048 
1049       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1050                                                DL);
1051       IRBuilder<>::InsertPoint AllocaIP(
1052           &OriginalFn->getEntryBlock(),
1053           OriginalFn->getEntryBlock().getFirstInsertionPt());
1054       // Create the merged parallel region with default proc binding, to
1055       // avoid overriding binding settings, and without explicit cancellation.
1056       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1057           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1058           OMP_PROC_BIND_default, /* IsCancellable */ false);
1059       BranchInst::Create(AfterBB, AfterIP.getBlock());
1060 
1061       // Perform the actual outlining.
1062       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1063                                        /* AllowExtractorSinking */ true);
1064 
1065       Function *OutlinedFn = MergableCIs.front()->getCaller();
1066 
1067       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1068       // callbacks.
1069       SmallVector<Value *, 8> Args;
1070       for (auto *CI : MergableCIs) {
1071         Value *Callee =
1072             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1073         FunctionType *FT =
1074             cast<FunctionType>(Callee->getType()->getPointerElementType());
1075         Args.clear();
1076         Args.push_back(OutlinedFn->getArg(0));
1077         Args.push_back(OutlinedFn->getArg(1));
1078         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1079              U < E; ++U)
1080           Args.push_back(CI->getArgOperand(U));
1081 
1082         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1083         if (CI->getDebugLoc())
1084           NewCI->setDebugLoc(CI->getDebugLoc());
1085 
1086         // Forward parameter attributes from the callback to the callee.
1087         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1088              U < E; ++U)
1089           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1090             NewCI->addParamAttr(
1091                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1092 
1093         // Emit an explicit barrier to replace the implicit fork-join barrier.
1094         if (CI != MergableCIs.back()) {
1095           // TODO: Remove barrier if the merged parallel region includes the
1096           // 'nowait' clause.
1097           OMPInfoCache.OMPBuilder.createBarrier(
1098               InsertPointTy(NewCI->getParent(),
1099                             NewCI->getNextNode()->getIterator()),
1100               OMPD_parallel);
1101         }
1102 
1103         CI->eraseFromParent();
1104       }
1105 
1106       assert(OutlinedFn != OriginalFn && "Outlining failed");
1107       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1108       CGUpdater.reanalyzeFunction(*OriginalFn);
1109 
1110       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1111 
1112       return true;
1113     };
1114 
1115     // Helper function that identifies sequences of __kmpc_fork_call uses in a
1116     // basic block.
1117     auto DetectPRsCB = [&](Use &U, Function &F) {
1118       CallInst *CI = getCallIfRegularCall(U, &RFI);
1119       BB2PRMap[CI->getParent()].insert(CI);
1120 
1121       return false;
1122     };
1123 
1124     BB2PRMap.clear();
1125     RFI.foreachUse(SCC, DetectPRsCB);
1126     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1127     // Find mergable parallel regions within a basic block that are
1128     // safe to merge, that is any in-between instructions can safely
1129     // execute in parallel after merging.
1130     // TODO: support merging across basic-blocks.
1131     for (auto &It : BB2PRMap) {
1132       auto &CIs = It.getSecond();
1133       if (CIs.size() < 2)
1134         continue;
1135 
1136       BasicBlock *BB = It.getFirst();
1137       SmallVector<CallInst *, 4> MergableCIs;
1138 
1139       /// Returns true if the instruction is mergable, false otherwise.
1140       /// A terminator instruction is unmergable by definition since merging
1141       /// works within a BB. Instructions before the mergable region are
1142       /// mergable if they are not calls to OpenMP runtime functions that may
1143       /// set different execution parameters for subsequent parallel regions.
1144       /// Instructions in-between parallel regions are mergable if they are not
1145       /// calls to any non-intrinsic function since that may call a non-mergable
1146       /// OpenMP runtime function.
1147       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1148         // We do not merge across BBs, hence return false (unmergable) if the
1149         // instruction is a terminator.
1150         if (I.isTerminator())
1151           return false;
1152 
1153         if (!isa<CallInst>(&I))
1154           return true;
1155 
1156         CallInst *CI = cast<CallInst>(&I);
1157         if (IsBeforeMergableRegion) {
1158           Function *CalledFunction = CI->getCalledFunction();
1159           if (!CalledFunction)
1160             return false;
1161           // Return false (unmergable) if the call before the parallel
1162           // region calls an explicit affinity (proc_bind) or number of
1163           // threads (num_threads) compiler-generated function. Those settings
1164           // may be incompatible with following parallel regions.
1165           // TODO: ICV tracking to detect compatibility.
1166           for (const auto &RFI : UnmergableCallsInfo) {
1167             if (CalledFunction == RFI.Declaration)
1168               return false;
1169           }
1170         } else {
1171           // Return false (unmergable) if there is a call instruction
1172           // in-between parallel regions when it is not an intrinsic. It
1173           // may call an unmergable OpenMP runtime function in its callpath.
1174           // TODO: Keep track of possible OpenMP calls in the callpath.
1175           if (!isa<IntrinsicInst>(CI))
1176             return false;
1177         }
1178 
1179         return true;
1180       };
1181       // Find maximal number of parallel region CIs that are safe to merge.
1182       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1183         Instruction &I = *It;
1184         ++It;
1185 
1186         if (CIs.count(&I)) {
1187           MergableCIs.push_back(cast<CallInst>(&I));
1188           continue;
1189         }
1190 
1191         // Continue expanding if the instruction is mergable.
1192         if (IsMergable(I, MergableCIs.empty()))
1193           continue;
1194 
1195         // Forward the instruction iterator to skip the next parallel region
1196         // since there is an unmergable instruction which can affect it.
1197         for (; It != End; ++It) {
1198           Instruction &SkipI = *It;
1199           if (CIs.count(&SkipI)) {
1200             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1201                               << " due to " << I << "\n");
1202             ++It;
1203             break;
1204           }
1205         }
1206 
1207         // Store mergable regions found.
1208         if (MergableCIs.size() > 1) {
1209           MergableCIsVector.push_back(MergableCIs);
1210           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1211                             << " parallel regions in block " << BB->getName()
1212                             << " of function " << BB->getParent()->getName()
1213                             << "\n";);
1214         }
1215 
1216         MergableCIs.clear();
1217       }
1218 
1219       if (!MergableCIsVector.empty()) {
1220         Changed = true;
1221 
1222         for (auto &MergableCIs : MergableCIsVector)
1223           Merge(MergableCIs, BB);
1224         MergableCIsVector.clear();
1225       }
1226     }
1227 
1228     if (Changed) {
1229       /// Re-collect uses for fork calls, emitted barrier calls, and
1230       /// any emitted master/end_master calls.
1231       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1232       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1233       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1234       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1235     }
1236 
1237     return Changed;
1238   }
1239 
1240   /// Try to delete parallel regions if possible.
1241   bool deleteParallelRegions() {
1242     const unsigned CallbackCalleeOperand = 2;
1243 
1244     OMPInformationCache::RuntimeFunctionInfo &RFI =
1245         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1246 
1247     if (!RFI.Declaration)
1248       return false;
1249 
1250     bool Changed = false;
1251     auto DeleteCallCB = [&](Use &U, Function &) {
1252       CallInst *CI = getCallIfRegularCall(U);
1253       if (!CI)
1254         return false;
1255       auto *Fn = dyn_cast<Function>(
1256           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1257       if (!Fn)
1258         return false;
1259       if (!Fn->onlyReadsMemory())
1260         return false;
1261       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1262         return false;
1263 
1264       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1265                         << CI->getCaller()->getName() << "\n");
1266 
1267       auto Remark = [&](OptimizationRemark OR) {
1268         return OR << "Removing parallel region with no side-effects.";
1269       };
1270       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1271 
1272       CGUpdater.removeCallSite(*CI);
1273       CI->eraseFromParent();
1274       Changed = true;
1275       ++NumOpenMPParallelRegionsDeleted;
1276       return true;
1277     };
1278 
1279     RFI.foreachUse(SCC, DeleteCallCB);
1280 
1281     return Changed;
1282   }
1283 
1284   /// Try to eliminate runtime calls by reusing existing ones.
1285   bool deduplicateRuntimeCalls() {
1286     bool Changed = false;
1287 
1288     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1289         OMPRTL_omp_get_num_threads,
1290         OMPRTL_omp_in_parallel,
1291         OMPRTL_omp_get_cancellation,
1292         OMPRTL_omp_get_thread_limit,
1293         OMPRTL_omp_get_supported_active_levels,
1294         OMPRTL_omp_get_level,
1295         OMPRTL_omp_get_ancestor_thread_num,
1296         OMPRTL_omp_get_team_size,
1297         OMPRTL_omp_get_active_level,
1298         OMPRTL_omp_in_final,
1299         OMPRTL_omp_get_proc_bind,
1300         OMPRTL_omp_get_num_places,
1301         OMPRTL_omp_get_num_procs,
1302         OMPRTL_omp_get_place_num,
1303         OMPRTL_omp_get_partition_num_places,
1304         OMPRTL_omp_get_partition_place_nums};
1305 
1306     // Global-tid is handled separately.
1307     SmallSetVector<Value *, 16> GTIdArgs;
1308     collectGlobalThreadIdArguments(GTIdArgs);
1309     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1310                       << " global thread ID arguments\n");
1311 
1312     for (Function *F : SCC) {
1313       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1314         Changed |= deduplicateRuntimeCalls(
1315             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1316 
1317       // __kmpc_global_thread_num is special as we can replace it with an
1318       // argument in enough cases to make it worth trying.
1319       Value *GTIdArg = nullptr;
1320       for (Argument &Arg : F->args())
1321         if (GTIdArgs.count(&Arg)) {
1322           GTIdArg = &Arg;
1323           break;
1324         }
1325       Changed |= deduplicateRuntimeCalls(
1326           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1327     }
1328 
1329     return Changed;
1330   }
1331 
  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
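  ///
  /// Conceptually (names are illustrative), the split turns
  ///   call void @__tgt_target_data_begin_mapper(...)
  ///   <code independent of the transfer>
  /// into
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   <code independent of the transfer>
  ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)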
1338   bool hideMemTransfersLatency() {
1339     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1340     bool Changed = false;
1341     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1342       auto *RTCall = getCallIfRegularCall(U, &RFI);
1343       if (!RTCall)
1344         return false;
1345 
1346       OffloadArray OffloadArrays[3];
1347       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1348         return false;
1349 
1350       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1351 
      // TODO: Check if it can be moved upwards.
1353       bool WasSplit = false;
1354       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1355       if (WaitMovementPoint)
1356         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1357 
1358       Changed |= WasSplit;
1359       return WasSplit;
1360     };
1361     RFI.foreachUse(SCC, SplitMemTransfers);
1362 
1363     return Changed;
1364   }
1365 
1366   void analysisGlobalization() {
1367     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1368 
1369     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1370       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1371         auto Remark = [&](OptimizationRemarkMissed ORM) {
1372           return ORM
1373                  << "Found thread data sharing on the GPU. "
1374                  << "Expect degraded performance due to data globalization.";
1375         };
1376         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1377       }
1378 
1379       return false;
1380     };
1381 
1382     RFI.foreachUse(SCC, CheckGlobalization);
1383   }
1384 
1385   /// Maps the values stored in the offload arrays passed as arguments to
1386   /// \p RuntimeCall into the offload arrays in \p OAs.
1387   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1388                                 MutableArrayRef<OffloadArray> OAs) {
1389     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1390 
1391     // A runtime call that involves memory offloading looks something like:
1392     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1393     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1394     // ...)
1395     // So, the idea is to access the allocas that allocate space for these
1396     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1397     // Therefore:
1398     // i8** %offload_baseptrs.
1399     Value *BasePtrsArg =
1400         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1401     // i8** %offload_ptrs.
1402     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1403     // i8** %offload_sizes.
1404     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1405 
1406     // Get values stored in **offload_baseptrs.
1407     auto *V = getUnderlyingObject(BasePtrsArg);
1408     if (!isa<AllocaInst>(V))
1409       return false;
1410     auto *BasePtrsArray = cast<AllocaInst>(V);
1411     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1412       return false;
1413 
    // Get values stored in **offload_ptrs.
1415     V = getUnderlyingObject(PtrsArg);
1416     if (!isa<AllocaInst>(V))
1417       return false;
1418     auto *PtrsArray = cast<AllocaInst>(V);
1419     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1420       return false;
1421 
1422     // Get values stored in **offload_sizes.
1423     V = getUnderlyingObject(SizesArg);
1424     // If it's a [constant] global array don't analyze it.
1425     if (isa<GlobalValue>(V))
1426       return isa<Constant>(V);
1427     if (!isa<AllocaInst>(V))
1428       return false;
1429 
1430     auto *SizesArray = cast<AllocaInst>(V);
1431     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1432       return false;
1433 
1434     return true;
1435   }
1436 
1437   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1438   /// For now this is a way to test that the function getValuesInOffloadArrays
1439   /// is working properly.
1440   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1441   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1442     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1443 
1444     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1445     std::string ValuesStr;
1446     raw_string_ostream Printer(ValuesStr);
1447     std::string Separator = " --- ";
1448 
1449     for (auto *BP : OAs[0].StoredValues) {
1450       BP->print(Printer);
1451       Printer << Separator;
1452     }
1453     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1454     ValuesStr.clear();
1455 
1456     for (auto *P : OAs[1].StoredValues) {
1457       P->print(Printer);
1458       Printer << Separator;
1459     }
1460     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1461     ValuesStr.clear();
1462 
1463     for (auto *S : OAs[2].StoredValues) {
1464       S->print(Printer);
1465       Printer << Separator;
1466     }
1467     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1468   }
1469 
  /// Returns the instruction to which the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
1472   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1473     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1474     //  Make it traverse the CFG.
1475 
1476     Instruction *CurrentI = &RuntimeCall;
1477     bool IsWorthIt = false;
1478     while ((CurrentI = CurrentI->getNextNode())) {
1479 
1480       // TODO: Once we detect the regions to be offloaded we should use the
1481       //  alias analysis manager to check if CurrentI may modify one of
1482       //  the offloaded regions.
1483       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1484         if (IsWorthIt)
1485           return CurrentI;
1486 
1487         return nullptr;
1488       }
1489 
      // FIXME: For now, moving it over anything without side effects is
      //  considered worth it.
1492       IsWorthIt = true;
1493     }
1494 
    // We reached the end of the BasicBlock; return its terminator.
1496     return RuntimeCall.getParent()->getTerminator();
1497   }
1498 
1499   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1500   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1501                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It stores information about the async transfer, allowing
    // us to wait on it later.
1505     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1506     auto *F = RuntimeCall.getCaller();
1507     Instruction *FirstInst = &(F->getEntryBlock().front());
1508     AllocaInst *Handle = new AllocaInst(
1509         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1510 
1511     // Add "issue" runtime call declaration:
    // declare %struct.tgt_async_info @__tgt_target_data_begin_mapper_issue(
    //   i64, i32, i8**, i8**, i64*, i64*)
1514     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1515         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1516 
    // Replace the RuntimeCall call site with its asynchronous version.
1518     SmallVector<Value *, 16> Args;
1519     for (auto &Arg : RuntimeCall.args())
1520       Args.push_back(Arg.get());
1521     Args.push_back(Handle);
1522 
1523     CallInst *IssueCallsite =
1524         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1525     RuntimeCall.eraseFromParent();
1526 
1527     // Add "wait" runtime call declaration:
    // declare void @__tgt_target_data_begin_mapper_wait(i64,
    //   %struct.__tgt_async_info)
1529     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1530         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1531 
1532     Value *WaitParams[2] = {
1533         IssueCallsite->getArgOperand(
1534             OffloadArray::DeviceIDArgNum), // device_id.
1535         Handle                             // handle to wait on.
1536     };
1537     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1538 
1539     return true;
1540   }
1541 
1542   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1543                                     bool GlobalOnly, bool &SingleChoice) {
1544     if (CurrentIdent == NextIdent)
1545       return CurrentIdent;
1546 
1547     // TODO: Figure out how to actually combine multiple debug locations. For
1548     //       now we just keep an existing one if there is a single choice.
1549     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1550       SingleChoice = !CurrentIdent;
1551       return NextIdent;
1552     }
1553     return nullptr;
1554   }
1555 
  /// Return a `struct ident_t*` value that represents the ones used in the
1557   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1558   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1559   /// return value we create one from scratch. We also do not yet combine
1560   /// information, e.g., the source locations, see combinedIdentStruct.
1561   Value *
1562   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1563                                  Function &F, bool GlobalOnly) {
1564     bool SingleChoice = true;
1565     Value *Ident = nullptr;
1566     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1567       CallInst *CI = getCallIfRegularCall(U, &RFI);
1568       if (!CI || &F != &Caller)
1569         return false;
1570       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1571                                   /* GlobalOnly */ true, SingleChoice);
1572       return false;
1573     };
1574     RFI.foreachUse(SCC, CombineIdentStruct);
1575 
1576     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate but we work around it for now.
1579       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1580         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1581             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1583       // TODO: Use the debug locations of the calls instead.
1584       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1585       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1586     }
1587     return Ident;
1588   }
1589 
1590   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1591   /// \p ReplVal if given.
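  ///
  /// For example (conceptually), two calls in \p F such as
  ///   %gtid0 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   ...
  ///   %gtid1 = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  /// are deduplicated by replacing all uses of %gtid1 with %gtid0 and erasing
  /// the second call.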
1592   bool deduplicateRuntimeCalls(Function &F,
1593                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1594                                Value *ReplVal = nullptr) {
1595     auto *UV = RFI.getUseVector(F);
1596     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1597       return false;
1598 
    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value" : "") << "\n");
1602 
1603     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1604                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1605            "Unexpected replacement value!");
1606 
1607     // TODO: Use dominance to find a good position instead.
1608     auto CanBeMoved = [this](CallBase &CB) {
1609       unsigned NumArgs = CB.getNumArgOperands();
1610       if (NumArgs == 0)
1611         return true;
1612       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1613         return false;
1614       for (unsigned u = 1; u < NumArgs; ++u)
1615         if (isa<Instruction>(CB.getArgOperand(u)))
1616           return false;
1617       return true;
1618     };
1619 
1620     if (!ReplVal) {
1621       for (Use *U : *UV)
1622         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1623           if (!CanBeMoved(*CI))
1624             continue;
1625 
1626           // If the function is a kernel, dedup will move
1627           // the runtime call right after the kernel init callsite. Otherwise,
1628           // it will move it to the beginning of the caller function.
1629           if (isKernel(F)) {
1630             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1631             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1632 
1633             if (KernelInitUV->empty())
1634               continue;
1635 
1636             assert(KernelInitUV->size() == 1 &&
1637                    "Expected a single __kmpc_target_init in kernel\n");
1638 
1639             CallInst *KernelInitCI =
1640                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1641             assert(KernelInitCI &&
1642                    "Expected a call to __kmpc_target_init in kernel\n");
1643 
1644             CI->moveAfter(KernelInitCI);
1645           } else
1646             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1647           ReplVal = CI;
1648           break;
1649         }
1650       if (!ReplVal)
1651         return false;
1652     }
1653 
1654     // If we use a call as a replacement value we need to make sure the ident is
1655     // valid at the new location. For now we just pick a global one, either
1656     // existing and used by one of the calls, or created from scratch.
1657     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1658       if (!CI->arg_empty() &&
1659           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1660         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1661                                                       /* GlobalOnly */ true);
1662         CI->setArgOperand(0, Ident);
1663       }
1664     }
1665 
1666     bool Changed = false;
1667     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1668       CallInst *CI = getCallIfRegularCall(U, &RFI);
1669       if (!CI || CI == ReplVal || &F != &Caller)
1670         return false;
1671       assert(CI->getCaller() == &F && "Unexpected call!");
1672 
1673       auto Remark = [&](OptimizationRemark OR) {
1674         return OR << "OpenMP runtime call "
1675                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1676       };
1677       if (CI->getDebugLoc())
1678         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1679       else
1680         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1681 
1682       CGUpdater.removeCallSite(*CI);
1683       CI->replaceAllUsesWith(ReplVal);
1684       CI->eraseFromParent();
1685       ++NumOpenMPRuntimeCallsDeduplicated;
1686       Changed = true;
1687       return true;
1688     };
1689     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1690 
1691     return Changed;
1692   }
1693 
1694   /// Collect arguments that represent the global thread id in \p GTIdArgs.
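  ///
  /// Conceptually: if every call site of an internal function passes either
  /// the result of a __kmpc_global_thread_num call or an already known GTId
  /// at a given argument position, then that argument is itself a GTId.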
1695   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1696     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1697     //       initialization. We could define an AbstractAttribute instead and
1698     //       run the Attributor here once it can be run as an SCC pass.
1699 
1700     // Helper to check the argument \p ArgNo at all call sites of \p F for
1701     // a GTId.
1702     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1703       if (!F.hasLocalLinkage())
1704         return false;
1705       for (Use &U : F.uses()) {
1706         if (CallInst *CI = getCallIfRegularCall(U)) {
1707           Value *ArgOp = CI->getArgOperand(ArgNo);
1708           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1709               getCallIfRegularCall(
1710                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1711             continue;
1712         }
1713         return false;
1714       }
1715       return true;
1716     };
1717 
1718     // Helper to identify uses of a GTId as GTId arguments.
1719     auto AddUserArgs = [&](Value &GTId) {
1720       for (Use &U : GTId.uses())
1721         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1722           if (CI->isArgOperand(&U))
1723             if (Function *Callee = CI->getCalledFunction())
1724               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1725                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1726     };
1727 
1728     // The argument users of __kmpc_global_thread_num calls are GTIds.
1729     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1730         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1731 
1732     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1733       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1734         AddUserArgs(*CI);
1735       return false;
1736     });
1737 
1738     // Transitively search for more arguments by looking at the users of the
1739     // ones we know already. During the search the GTIdArgs vector is extended
1740     // so we cannot cache the size nor can we use a range based for.
1741     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1742       AddUserArgs(*GTIdArgs[u]);
1743   }
1744 
1745   /// Kernel (=GPU) optimizations and utility functions
1746   ///
1747   ///{{
1748 
1749   /// Check if \p F is a kernel, hence entry point for target offloading.
1750   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1751 
1752   /// Cache to remember the unique kernel for a function.
1753   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1754 
1755   /// Find the unique kernel that will execute \p F, if any.
1756   Kernel getUniqueKernelFor(Function &F);
1757 
1758   /// Find the unique kernel that will execute \p I, if any.
1759   Kernel getUniqueKernelFor(Instruction &I) {
1760     return getUniqueKernelFor(*I.getFunction());
1761   }
1762 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
1765   bool rewriteDeviceCodeStateMachine();
1766 
1767   ///
1768   ///}}
1769 
1770   /// Emit a remark generically
1771   ///
1772   /// This template function can be used to generically emit a remark. The
1773   /// RemarkKind should be one of the following:
1774   ///   - OptimizationRemark to indicate a successful optimization attempt
1775   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1776   ///   - OptimizationRemarkAnalysis to provide additional information about an
1777   ///     optimization attempt
1778   ///
1779   /// The remark is built using a callback function provided by the caller that
1780   /// takes a RemarkKind as input and returns a RemarkKind.
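  ///
  /// A typical use, mirroring e.g. deleteParallelRegions above, looks like:
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "Removing parallel region with no side-effects.";
  ///   };
  ///   emitRemark<OptimizationRemark>(CI, "OMP160", Remark);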
1781   template <typename RemarkKind, typename RemarkCallBack>
1782   void emitRemark(Instruction *I, StringRef RemarkName,
1783                   RemarkCallBack &&RemarkCB) const {
1784     Function *F = I->getParent()->getParent();
1785     auto &ORE = OREGetter(F);
1786 
1787     if (RemarkName.startswith("OMP"))
1788       ORE.emit([&]() {
1789         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1790                << " [" << RemarkName << "]";
1791       });
1792     else
1793       ORE.emit(
1794           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1795   }
1796 
1797   /// Emit a remark on a function.
1798   template <typename RemarkKind, typename RemarkCallBack>
1799   void emitRemark(Function *F, StringRef RemarkName,
1800                   RemarkCallBack &&RemarkCB) const {
1801     auto &ORE = OREGetter(F);
1802 
1803     if (RemarkName.startswith("OMP"))
1804       ORE.emit([&]() {
1805         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1806                << " [" << RemarkName << "]";
1807       });
1808     else
1809       ORE.emit(
1810           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1811   }
1812 
1813   /// RAII struct to temporarily change an RTL function's linkage to external.
1814   /// This prevents it from being mistakenly removed by other optimizations.
1815   struct ExternalizationRAII {
1816     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1817                         RuntimeFunction RFKind)
1818         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1819       if (!Declaration)
1820         return;
1821 
1822       LinkageType = Declaration->getLinkage();
1823       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1824     }
1825 
1826     ~ExternalizationRAII() {
1827       if (!Declaration)
1828         return;
1829 
1830       Declaration->setLinkage(LinkageType);
1831     }
1832 
1833     Function *Declaration;
1834     GlobalValue::LinkageTypes LinkageType;
1835   };
1836 
1837   /// The underlying module.
1838   Module &M;
1839 
1840   /// The SCC we are operating on.
1841   SmallVectorImpl<Function *> &SCC;
1842 
1843   /// Callback to update the call graph, the first argument is a removed call,
1844   /// the second an optional replacement call.
1845   CallGraphUpdater &CGUpdater;
1846 
1847   /// Callback to get an OptimizationRemarkEmitter from a Function *
1848   OptimizationRemarkGetter OREGetter;
1849 
  /// OpenMP-specific information cache. Also used for Attributor runs.
1851   OMPInformationCache &OMPInfoCache;
1852 
1853   /// Attributor instance.
1854   Attributor &A;
1855 
1856   /// Helper function to run Attributor on SCC.
1857   bool runAttributor(bool IsModulePass) {
1858     if (SCC.empty())
1859       return false;
1860 
    // Temporarily make these functions have external linkage so the Attributor
    // doesn't remove them when we try to look them up later.
1863     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1864     ExternalizationRAII EndParallel(OMPInfoCache,
1865                                     OMPRTL___kmpc_kernel_end_parallel);
1866     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1867                                     OMPRTL___kmpc_barrier_simple_spmd);
1868 
1869     registerAAs(IsModulePass);
1870 
1871     ChangeStatus Changed = A.run();
1872 
1873     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1874                       << " functions, result: " << Changed << ".\n");
1875 
1876     return Changed == ChangeStatus::CHANGED;
1877   }
1878 
1879   void registerFoldRuntimeCall(RuntimeFunction RF);
1880 
1881   /// Populate the Attributor with abstract attribute opportunities in the
1882   /// function.
1883   void registerAAs(bool IsModulePass);
1884 };
1885 
1886 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1887   if (!OMPInfoCache.ModuleSlice.count(&F))
1888     return nullptr;
1889 
1890   // Use a scope to keep the lifetime of the CachedKernel short.
1891   {
1892     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1893     if (CachedKernel)
1894       return *CachedKernel;
1895 
1896     // TODO: We should use an AA to create an (optimistic and callback
1897     //       call-aware) call graph. For now we stick to simple patterns that
1898     //       are less powerful, basically the worst fixpoint.
1899     if (isKernel(F)) {
1900       CachedKernel = Kernel(&F);
1901       return *CachedKernel;
1902     }
1903 
1904     CachedKernel = nullptr;
1905     if (!F.hasLocalLinkage()) {
1906 
1907       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1908       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1909         return ORA << "Potentially unknown OpenMP target region caller.";
1910       };
1911       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1912 
1913       return nullptr;
1914     }
1915   }
1916 
1917   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1918     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1919       // Allow use in equality comparisons.
1920       if (Cmp->isEquality())
1921         return getUniqueKernelFor(*Cmp);
1922       return nullptr;
1923     }
1924     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1925       // Allow direct calls.
1926       if (CB->isCallee(&U))
1927         return getUniqueKernelFor(*CB);
1928 
1929       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1930           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1931       // Allow the use in __kmpc_parallel_51 calls.
1932       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1933         return getUniqueKernelFor(*CB);
1934       return nullptr;
1935     }
1936     // Disallow every other use.
1937     return nullptr;
1938   };
1939 
1940   // TODO: In the future we want to track more than just a unique kernel.
1941   SmallPtrSet<Kernel, 2> PotentialKernels;
1942   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1943     PotentialKernels.insert(GetUniqueKernelForUse(U));
1944   });
1945 
1946   Kernel K = nullptr;
1947   if (PotentialKernels.size() == 1)
1948     K = *PotentialKernels.begin();
1949 
1950   // Cache the result.
1951   UniqueKernelMap[&F] = K;
1952 
1953   return K;
1954 }
1955 
1956 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1957   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1958       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1959 
1960   bool Changed = false;
1961   if (!KernelParallelRFI)
1962     return Changed;
1963 
1964   // If we have disabled state machine changes, exit
1965   if (DisableOpenMPOptStateMachineRewrite)
1966     return Changed;
1967 
1968   for (Function *F : SCC) {
1969 
1970     // Check if the function is a use in a __kmpc_parallel_51 call at
1971     // all.
1972     bool UnknownUse = false;
1973     bool KernelParallelUse = false;
1974     unsigned NumDirectCalls = 0;
1975 
1976     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1977     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1978       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1979         if (CB->isCallee(&U)) {
1980           ++NumDirectCalls;
1981           return;
1982         }
1983 
1984       if (isa<ICmpInst>(U.getUser())) {
1985         ToBeReplacedStateMachineUses.push_back(&U);
1986         return;
1987       }
1988 
1989       // Find wrapper functions that represent parallel kernels.
1990       CallInst *CI =
1991           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1992       const unsigned int WrapperFunctionArgNo = 6;
1993       if (!KernelParallelUse && CI &&
1994           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1995         KernelParallelUse = true;
1996         ToBeReplacedStateMachineUses.push_back(&U);
1997         return;
1998       }
1999       UnknownUse = true;
2000     });
2001 
2002     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2003     // use.
2004     if (!KernelParallelUse)
2005       continue;
2006 
2007     // If this ever hits, we should investigate.
2008     // TODO: Checking the number of uses is not a necessary restriction and
2009     // should be lifted.
2010     if (UnknownUse || NumDirectCalls != 1 ||
2011         ToBeReplacedStateMachineUses.size() > 2) {
2012       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2013         return ORA << "Parallel region is used in "
2014                    << (UnknownUse ? "unknown" : "unexpected")
2015                    << " ways. Will not attempt to rewrite the state machine.";
2016       };
2017       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2018       continue;
2019     }
2020 
2021     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2022     // up if the function is not called from a unique kernel.
2023     Kernel K = getUniqueKernelFor(*F);
2024     if (!K) {
2025       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2026         return ORA << "Parallel region is not called from a unique kernel. "
2027                       "Will not attempt to rewrite the state machine.";
2028       };
2029       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2030       continue;
2031     }
2032 
2033     // We now know F is a parallel body function called only from the kernel K.
2034     // We also identified the state machine uses in which we replace the
2035     // function pointer by a new global symbol for identification purposes. This
2036     // ensures only direct calls to the function are left.
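    // Conceptually (names are illustrative), a state machine comparison like
    //   icmp eq i8* %work_fn, bitcast (void ()* @par_fn to i8*)
    // becomes
    //   icmp eq i8* %work_fn, @par_fn.ID
    // so the parallel region function's address is no longer taken and only
    // direct calls to it remain.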
2037 
2038     Module &M = *F->getParent();
2039     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2040 
2041     auto *ID = new GlobalVariable(
2042         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2043         UndefValue::get(Int8Ty), F->getName() + ".ID");
2044 
2045     for (Use *U : ToBeReplacedStateMachineUses)
2046       U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2047           ID, U->get()->getType()));
2048 
2049     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2050 
2051     Changed = true;
2052   }
2053 
2054   return Changed;
2055 }
2056 
2057 /// Abstract Attribute for tracking ICV values.
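/// Conceptually, with nthreads as an example: after a setter call such as
///   call void @omp_set_num_threads(i32 4)
/// a later getter call
///   %n = call i32 @omp_get_max_threads()
/// can be folded to 4, provided no intervening call may change the ICV.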
2058 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2059   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2060   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2061 
2062   void initialize(Attributor &A) override {
2063     Function *F = getAnchorScope();
2064     if (!F || !A.isFunctionIPOAmendable(*F))
2065       indicatePessimisticFixpoint();
2066   }
2067 
2068   /// Returns true if value is assumed to be tracked.
2069   bool isAssumedTracked() const { return getAssumed(); }
2070 
2071   /// Returns true if value is known to be tracked.
2072   bool isKnownTracked() const { return getAssumed(); }
2073 
  /// Create an abstract attribute view for the position \p IRP.
2075   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2076 
2077   /// Return the value with which \p I can be replaced for specific \p ICV.
2078   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2079                                                 const Instruction *I,
2080                                                 Attributor &A) const {
2081     return None;
2082   }
2083 
  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return a nullptr. If it is not clear yet, return
  /// None.
2087   virtual Optional<Value *>
2088   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2089 
  // Currently only nthreads is being tracked.
  // This array will only grow with time.
2092   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2093 
2094   /// See AbstractAttribute::getName()
2095   const std::string getName() const override { return "AAICVTracker"; }
2096 
2097   /// See AbstractAttribute::getIdAddr()
2098   const char *getIdAddr() const override { return &ID; }
2099 
2100   /// This function should return true if the type of the \p AA is AAICVTracker
2101   static bool classof(const AbstractAttribute *AA) {
2102     return (AA->getIdAddr() == &ID);
2103   }
2104 
2105   static const char ID;
2106 };
2107 
2108 struct AAICVTrackerFunction : public AAICVTracker {
2109   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2110       : AAICVTracker(IRP, A) {}
2111 
2112   // FIXME: come up with better string.
2113   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2114 
2115   // FIXME: come up with some stats.
2116   void trackStatistics() const override {}
2117 
2118   /// We don't manifest anything for this AA.
2119   ChangeStatus manifest(Attributor &A) override {
2120     return ChangeStatus::UNCHANGED;
2121   }
2122 
  // Map of ICVs to their values at specific program points.
2124   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2125                   InternalControlVar::ICV___last>
2126       ICVReplacementValuesMap;
2127 
2128   ChangeStatus updateImpl(Attributor &A) override {
2129     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2130 
2131     Function *F = getAnchorScope();
2132 
2133     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2134 
2135     for (InternalControlVar ICV : TrackableICVs) {
2136       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2137 
2138       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2139       auto TrackValues = [&](Use &U, Function &) {
2140         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2141         if (!CI)
2142           return false;
2143 
        // FIXME: handle setters with more than one argument.
2145         /// Track new value.
2146         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2147           HasChanged = ChangeStatus::CHANGED;
2148 
2149         return false;
2150       };
2151 
2152       auto CallCheck = [&](Instruction &I) {
2153         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2154         if (ReplVal.hasValue() &&
2155             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2156           HasChanged = ChangeStatus::CHANGED;
2157 
2158         return true;
2159       };
2160 
2161       // Track all changes of an ICV.
2162       SetterRFI.foreachUse(TrackValues, F);
2163 
2164       bool UsedAssumedInformation = false;
2165       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2166                                 UsedAssumedInformation,
2167                                 /* CheckBBLivenessOnly */ true);
2168 
      /// TODO: Figure out a way to avoid adding an entry in
      /// ICVReplacementValuesMap.
2171       Instruction *Entry = &F->getEntryBlock().front();
2172       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2173         ValuesMap.insert(std::make_pair(Entry, nullptr));
2174     }
2175 
2176     return HasChanged;
2177   }
2178 
  /// Helper to check if \p I is a call and to get the value for it if it is
  /// unique.
2181   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2182                                     InternalControlVar &ICV) const {
2183 
2184     const auto *CB = dyn_cast<CallBase>(I);
2185     if (!CB || CB->hasFnAttr("no_openmp") ||
2186         CB->hasFnAttr("no_openmp_routines"))
2187       return None;
2188 
2189     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2190     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2191     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2192     Function *CalledFunction = CB->getCalledFunction();
2193 
2194     // Indirect call, assume ICV changes.
2195     if (CalledFunction == nullptr)
2196       return nullptr;
2197     if (CalledFunction == GetterRFI.Declaration)
2198       return None;
2199     if (CalledFunction == SetterRFI.Declaration) {
2200       if (ICVReplacementValuesMap[ICV].count(I))
2201         return ICVReplacementValuesMap[ICV].lookup(I);
2202 
2203       return nullptr;
2204     }
2205 
2206     // Since we don't know, assume it changes the ICV.
2207     if (CalledFunction->isDeclaration())
2208       return nullptr;
2209 
2210     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2211         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2212 
2213     if (ICVTrackingAA.isAssumedTracked())
2214       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2215 
2216     // If we don't know, assume it changes.
2217     return nullptr;
2218   }
2219 
  // We don't track a unique value for the whole function, so return None.
2221   Optional<Value *>
2222   getUniqueReplacementValue(InternalControlVar ICV) const override {
2223     return None;
2224   }
2225 
2226   /// Return the value with which \p I can be replaced for specific \p ICV.
2227   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2228                                         const Instruction *I,
2229                                         Attributor &A) const override {
2230     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2231     if (ValuesMap.count(I))
2232       return ValuesMap.lookup(I);
2233 
2234     SmallVector<const Instruction *, 16> Worklist;
2235     SmallPtrSet<const Instruction *, 16> Visited;
2236     Worklist.push_back(I);
2237 
2238     Optional<Value *> ReplVal;
2239 
2240     while (!Worklist.empty()) {
2241       const Instruction *CurrInst = Worklist.pop_back_val();
2242       if (!Visited.insert(CurrInst).second)
2243         continue;
2244 
2245       const BasicBlock *CurrBB = CurrInst->getParent();
2246 
2247       // Go up and look for all potential setters/calls that might change the
2248       // ICV.
2249       while ((CurrInst = CurrInst->getPrevNode())) {
2250         if (ValuesMap.count(CurrInst)) {
2251           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2252           // Unknown value, track new.
2253           if (!ReplVal.hasValue()) {
2254             ReplVal = NewReplVal;
2255             break;
2256           }
2257 
          // If we found a new value, we can't know the ICV value anymore.
2259           if (NewReplVal.hasValue())
2260             if (ReplVal != NewReplVal)
2261               return nullptr;
2262 
2263           break;
2264         }
2265 
2266         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2267         if (!NewReplVal.hasValue())
2268           continue;
2269 
2270         // Unknown value, track new.
2271         if (!ReplVal.hasValue()) {
2272           ReplVal = NewReplVal;
2273           break;
2274         }
2275 
        // We found a new value, so we can't know the ICV value anymore.
2278         if (ReplVal != NewReplVal)
2279           return nullptr;
2280       }
2281 
2282       // If we are in the same BB and we have a value, we are done.
2283       if (CurrBB == I->getParent() && ReplVal.hasValue())
2284         return ReplVal;
2285 
2286       // Go through all predecessors and add terminators for analysis.
2287       for (const BasicBlock *Pred : predecessors(CurrBB))
2288         if (const Instruction *Terminator = Pred->getTerminator())
2289           Worklist.push_back(Terminator);
2290     }
2291 
2292     return ReplVal;
2293   }
2294 };
2295 
2296 struct AAICVTrackerFunctionReturned : AAICVTracker {
2297   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2298       : AAICVTracker(IRP, A) {}
2299 
2300   // FIXME: come up with better string.
2301   const std::string getAsStr() const override {
2302     return "ICVTrackerFunctionReturned";
2303   }
2304 
2305   // FIXME: come up with some stats.
2306   void trackStatistics() const override {}
2307 
2308   /// We don't manifest anything for this AA.
2309   ChangeStatus manifest(Attributor &A) override {
2310     return ChangeStatus::UNCHANGED;
2311   }
2312 
  // Map of ICVs to their values at specific program points.
2314   EnumeratedArray<Optional<Value *>, InternalControlVar,
2315                   InternalControlVar::ICV___last>
2316       ICVReplacementValuesMap;
2317 
  /// Return the unique ICV value assumed at the function's return sites, if
  /// any, for \p ICV.
2319   Optional<Value *>
2320   getUniqueReplacementValue(InternalControlVar ICV) const override {
2321     return ICVReplacementValuesMap[ICV];
2322   }
2323 
2324   ChangeStatus updateImpl(Attributor &A) override {
2325     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2326     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2327         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2328 
2329     if (!ICVTrackingAA.isAssumedTracked())
2330       return indicatePessimisticFixpoint();
2331 
2332     for (InternalControlVar ICV : TrackableICVs) {
2333       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2334       Optional<Value *> UniqueICVValue;
2335 
2336       auto CheckReturnInst = [&](Instruction &I) {
2337         Optional<Value *> NewReplVal =
2338             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2339 
2340         // If we found a second ICV value there is no unique returned value.
2341         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2342           return false;
2343 
2344         UniqueICVValue = NewReplVal;
2345 
2346         return true;
2347       };
2348 
2349       bool UsedAssumedInformation = false;
2350       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2351                                      UsedAssumedInformation,
2352                                      /* CheckBBLivenessOnly */ true))
2353         UniqueICVValue = nullptr;
2354 
2355       if (UniqueICVValue == ReplVal)
2356         continue;
2357 
2358       ReplVal = UniqueICVValue;
2359       Changed = ChangeStatus::CHANGED;
2360     }
2361 
2362     return Changed;
2363   }
2364 };
2365 
2366 struct AAICVTrackerCallSite : AAICVTracker {
2367   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2368       : AAICVTracker(IRP, A) {}
2369 
2370   void initialize(Attributor &A) override {
2371     Function *F = getAnchorScope();
2372     if (!F || !A.isFunctionIPOAmendable(*F))
2373       indicatePessimisticFixpoint();
2374 
2375     // We only initialize this AA for getters, so we need to know which ICV it
2376     // gets.
2377     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2378     for (InternalControlVar ICV : TrackableICVs) {
2379       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2380       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2381       if (Getter.Declaration == getAssociatedFunction()) {
2382         AssociatedICV = ICVInfo.Kind;
2383         return;
2384       }
2385     }
2386 
2387     /// Unknown ICV.
2388     indicatePessimisticFixpoint();
2389   }
2390 
2391   ChangeStatus manifest(Attributor &A) override {
2392     if (!ReplVal.hasValue() || !ReplVal.getValue())
2393       return ChangeStatus::UNCHANGED;
2394 
2395     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2396     A.deleteAfterManifest(*getCtxI());
2397 
2398     return ChangeStatus::CHANGED;
2399   }
2400 
2401   // FIXME: come up with better string.
2402   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2403 
2404   // FIXME: come up with some stats.
2405   void trackStatistics() const override {}
2406 
2407   InternalControlVar AssociatedICV;
2408   Optional<Value *> ReplVal;
2409 
2410   ChangeStatus updateImpl(Attributor &A) override {
2411     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2412         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2413 
2414     // We don't have any information, so we assume it changes the ICV.
2415     if (!ICVTrackingAA.isAssumedTracked())
2416       return indicatePessimisticFixpoint();
2417 
2418     Optional<Value *> NewReplVal =
2419         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2420 
2421     if (ReplVal == NewReplVal)
2422       return ChangeStatus::UNCHANGED;
2423 
2424     ReplVal = NewReplVal;
2425     return ChangeStatus::CHANGED;
2426   }
2427 
  // Return the value with which the associated value can be replaced for the
  // specific \p ICV.
2430   Optional<Value *>
2431   getUniqueReplacementValue(InternalControlVar ICV) const override {
2432     return ReplVal;
2433   }
2434 };
2435 
2436 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2437   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2438       : AAICVTracker(IRP, A) {}
2439 
2440   // FIXME: come up with better string.
2441   const std::string getAsStr() const override {
2442     return "ICVTrackerCallSiteReturned";
2443   }
2444 
2445   // FIXME: come up with some stats.
2446   void trackStatistics() const override {}
2447 
2448   /// We don't manifest anything for this AA.
2449   ChangeStatus manifest(Attributor &A) override {
2450     return ChangeStatus::UNCHANGED;
2451   }
2452 
  // Map of ICVs to their values at specific program points.
2454   EnumeratedArray<Optional<Value *>, InternalControlVar,
2455                   InternalControlVar::ICV___last>
2456       ICVReplacementValuesMap;
2457 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2460   Optional<Value *>
2461   getUniqueReplacementValue(InternalControlVar ICV) const override {
2462     return ICVReplacementValuesMap[ICV];
2463   }
2464 
2465   ChangeStatus updateImpl(Attributor &A) override {
2466     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2467     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2468         *this, IRPosition::returned(*getAssociatedFunction()),
2469         DepClassTy::REQUIRED);
2470 
2471     // We don't have any information, so we assume it changes the ICV.
2472     if (!ICVTrackingAA.isAssumedTracked())
2473       return indicatePessimisticFixpoint();
2474 
2475     for (InternalControlVar ICV : TrackableICVs) {
2476       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2477       Optional<Value *> NewReplVal =
2478           ICVTrackingAA.getUniqueReplacementValue(ICV);
2479 
2480       if (ReplVal == NewReplVal)
2481         continue;
2482 
2483       ReplVal = NewReplVal;
2484       Changed = ChangeStatus::CHANGED;
2485     }
2486     return Changed;
2487   }
2488 };
2489 
2490 struct AAExecutionDomainFunction : public AAExecutionDomain {
2491   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2492       : AAExecutionDomain(IRP, A) {}
2493 
2494   const std::string getAsStr() const override {
2495     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2496            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2497   }
2498 
2499   /// See AbstractAttribute::trackStatistics().
2500   void trackStatistics() const override {}
2501 
2502   void initialize(Attributor &A) override {
2503     Function *F = getAnchorScope();
2504     for (const auto &BB : *F)
2505       SingleThreadedBBs.insert(&BB);
2506     NumBBs = SingleThreadedBBs.size();
2507   }
2508 
2509   ChangeStatus manifest(Attributor &A) override {
2510     LLVM_DEBUG({
2511       for (const BasicBlock *BB : SingleThreadedBBs)
2512         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2513                << BB->getName() << " is executed by a single thread.\n";
2514     });
2515     return ChangeStatus::UNCHANGED;
2516   }
2517 
2518   ChangeStatus updateImpl(Attributor &A) override;
2519 
2520   /// Check if an instruction is executed by a single thread.
2521   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2522     return isExecutedByInitialThreadOnly(*I.getParent());
2523   }
2524 
2525   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2526     return isValidState() && SingleThreadedBBs.contains(&BB);
2527   }
2528 
2529   /// Set of basic blocks that are executed by a single thread.
2530   DenseSet<const BasicBlock *> SingleThreadedBBs;
2531 
2532   /// Total number of basic blocks in this function.
2533   long unsigned NumBBs;
2534 };
2535 
2536 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2537   Function *F = getAnchorScope();
2538   ReversePostOrderTraversal<Function *> RPOT(F);
2539   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2540 
2541   bool AllCallSitesKnown;
2542   auto PredForCallSite = [&](AbstractCallSite ACS) {
2543     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2544         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2545         DepClassTy::REQUIRED);
2546     return ACS.isDirectCall() &&
2547            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2548                *ACS.getInstruction());
2549   };
2550 
2551   if (!A.checkForAllCallSites(PredForCallSite, *this,
2552                               /* RequiresAllCallSites */ true,
2553                               AllCallSitesKnown))
2554     SingleThreadedBBs.erase(&F->getEntryBlock());
2555 
2556   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2557   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2558 
2559   // Check if the edge into the successor block contains a condition that only
2560   // lets the main thread execute it.
2561   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2562     if (!Edge || !Edge->isConditional())
2563       return false;
2564     if (Edge->getSuccessor(0) != SuccessorBB)
2565       return false;
2566 
2567     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2568     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2569       return false;
2570 
2571     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2572     if (!C)
2573       return false;
2574 
2575     // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2576     if (C->isAllOnesValue()) {
2577       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2578       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2579       if (!CB)
2580         return false;
2581       const int InitIsSPMDArgNo = 1;
2582       auto *IsSPMDModeCI =
2583           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2584       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2585     }
2586 
2587     if (C->isZero()) {
2588       // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2589       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2590         if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2591           return true;
2592 
2593       // Match: 0 == llvm.amdgcn.workitem.id.x()
2594       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2595         if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2596           return true;
2597     }
2598 
2599     return false;
2600   };
2601 
2602   // Merge all the predecessor states into the current basic block. A basic
2603   // block is executed by a single thread if all of its predecessors are.
2604   auto MergePredecessorStates = [&](BasicBlock *BB) {
2605     if (pred_begin(BB) == pred_end(BB))
2606       return SingleThreadedBBs.contains(BB);
2607 
2608     bool IsInitialThread = true;
2609     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2610          PredBB != PredEndBB; ++PredBB) {
2611       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2612                                BB))
2613         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2614     }
2615 
2616     return IsInitialThread;
2617   };
2618 
2619   for (auto *BB : RPOT) {
2620     if (!MergePredecessorStates(BB))
2621       SingleThreadedBBs.erase(BB);
2622   }
2623 
2624   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2625              ? ChangeStatus::UNCHANGED
2626              : ChangeStatus::CHANGED;
2627 }
2628 
/// Try to replace memory allocation calls that are executed by a single
/// thread with a static buffer of shared memory.
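///
/// Conceptually, an allocation/deallocation pair such as
///   %p = call i8* @__kmpc_alloc_shared(i64 8)
///   ...
///   call void @__kmpc_free_shared(...)
/// is rewritten to use a static buffer like
///   @x.shared = internal addrspace(3) global [8 x i8] undef
/// when the allocation is known to be executed by the initial thread only.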
2631 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2632   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2633   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2634 
2635   /// Create an abstract attribute view for the position \p IRP.
2636   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2637                                            Attributor &A);
2638 
2639   /// Returns true if HeapToShared conversion is assumed to be possible.
2640   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2641 
2642   /// Returns true if HeapToShared conversion is assumed and the CB is a
2643   /// callsite to a free operation to be removed.
2644   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2645 
2646   /// See AbstractAttribute::getName().
2647   const std::string getName() const override { return "AAHeapToShared"; }
2648 
2649   /// See AbstractAttribute::getIdAddr().
2650   const char *getIdAddr() const override { return &ID; }
2651 
2652   /// This function should return true if the type of the \p AA is
2653   /// AAHeapToShared.
2654   static bool classof(const AbstractAttribute *AA) {
2655     return (AA->getIdAddr() == &ID);
2656   }
2657 
2658   /// Unique ID (due to the unique address)
2659   static const char ID;
2660 };
2661 
2662 struct AAHeapToSharedFunction : public AAHeapToShared {
2663   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2664       : AAHeapToShared(IRP, A) {}
2665 
2666   const std::string getAsStr() const override {
2667     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2668            " malloc calls eligible.";
2669   }
2670 
2671   /// See AbstractAttribute::trackStatistics().
2672   void trackStatistics() const override {}
2673 
  /// This function finds free calls that will be removed by the
  /// HeapToShared transformation.
2676   void findPotentialRemovedFreeCalls(Attributor &A) {
2677     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2678     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2679 
2680     PotentialRemovedFreeCalls.clear();
2681     // Update free call users of found malloc calls.
2682     for (CallBase *CB : MallocCalls) {
2683       SmallVector<CallBase *, 4> FreeCalls;
2684       for (auto *U : CB->users()) {
2685         CallBase *C = dyn_cast<CallBase>(U);
2686         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2687           FreeCalls.push_back(C);
2688       }
2689 
2690       if (FreeCalls.size() != 1)
2691         continue;
2692 
2693       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2694     }
2695   }
2696 
2697   void initialize(Attributor &A) override {
2698     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2699     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2700 
2701     for (User *U : RFI.Declaration->users())
2702       if (CallBase *CB = dyn_cast<CallBase>(U))
2703         MallocCalls.insert(CB);
2704 
2705     findPotentialRemovedFreeCalls(A);
2706   }
2707 
2708   bool isAssumedHeapToShared(CallBase &CB) const override {
2709     return isValidState() && MallocCalls.count(&CB);
2710   }
2711 
2712   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2713     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2714   }
2715 
2716   ChangeStatus manifest(Attributor &A) override {
2717     if (MallocCalls.empty())
2718       return ChangeStatus::UNCHANGED;
2719 
2720     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2721     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2722 
2723     Function *F = getAnchorScope();
2724     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2725                                             DepClassTy::OPTIONAL);
2726 
2727     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2728     for (CallBase *CB : MallocCalls) {
2729       // Skip replacing this if HeapToStack has already claimed it.
2730       if (HS && HS->isAssumedHeapToStack(*CB))
2731         continue;
2732 
2733       // Find the unique free call to remove it.
2734       SmallVector<CallBase *, 4> FreeCalls;
2735       for (auto *U : CB->users()) {
2736         CallBase *C = dyn_cast<CallBase>(U);
2737         if (C && C->getCalledFunction() == FreeCall.Declaration)
2738           FreeCalls.push_back(C);
2739       }
2740       if (FreeCalls.size() != 1)
2741         continue;
2742 
2743       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2744 
2745       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2746                         << " with " << AllocSize->getZExtValue()
2747                         << " bytes of shared memory\n");
2748 
2749       // Create a new shared memory buffer of the same size as the allocation
2750       // and replace all the uses of the original allocation with it.
2751       Module *M = CB->getModule();
2752       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2753       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2754       auto *SharedMem = new GlobalVariable(
2755           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2756           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2757           GlobalValue::NotThreadLocal,
2758           static_cast<unsigned>(AddressSpace::Shared));
2759       auto *NewBuffer =
2760           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2761 
2762       auto Remark = [&](OptimizationRemark OR) {
2763         return OR << "Replaced globalized variable with "
2764                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2765                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2766                   << "of shared memory.";
2767       };
2768       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2769 
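      // Give the shared memory buffer a fixed, conservatively large 32-byte
      // alignment, independent of the original allocation's alignment.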
2770       SharedMem->setAlignment(MaybeAlign(32));
2771 
2772       A.changeValueAfterManifest(*CB, *NewBuffer);
2773       A.deleteAfterManifest(*CB);
2774       A.deleteAfterManifest(*FreeCalls.front());
2775 
2776       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2777       Changed = ChangeStatus::CHANGED;
2778     }
2779 
2780     return Changed;
2781   }
2782 
2783   ChangeStatus updateImpl(Attributor &A) override {
2784     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2785     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2786     Function *F = getAnchorScope();
2787 
2788     auto NumMallocCalls = MallocCalls.size();
2789 
    // Only keep malloc calls that have a constant allocation size and are
    // executed by the initial thread only.
2791     for (User *U : RFI.Declaration->users()) {
2792       const auto &ED = A.getAAFor<AAExecutionDomain>(
2793           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2794       if (CallBase *CB = dyn_cast<CallBase>(U))
2795         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2796             !ED.isExecutedByInitialThreadOnly(*CB))
2797           MallocCalls.erase(CB);
2798     }
2799 
2800     findPotentialRemovedFreeCalls(A);
2801 
2802     if (NumMallocCalls != MallocCalls.size())
2803       return ChangeStatus::CHANGED;
2804 
2805     return ChangeStatus::UNCHANGED;
2806   }
2807 
2808   /// Collection of all malloc calls in a function.
2809   SmallPtrSet<CallBase *, 4> MallocCalls;
2810   /// Collection of potentially removed free calls in a function.
2811   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2812 };
2813 
2814 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2815   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2816   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2817 
2818   /// Statistics are tracked as part of manifest for now.
2819   void trackStatistics() const override {}
2820 
2821   /// See AbstractAttribute::getAsStr()
2822   const std::string getAsStr() const override {
2823     if (!isValidState())
2824       return "<invalid>";
2825     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2826                                                             : "generic") +
2827            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2828                                                                : "") +
2829            std::string(" #PRs: ") +
2830            std::to_string(ReachedKnownParallelRegions.size()) +
2831            ", #Unknown PRs: " +
2832            std::to_string(ReachedUnknownParallelRegions.size());
2833   }
2834 
  /// Create an abstract attribute view for the position \p IRP.
2836   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2837 
2838   /// See AbstractAttribute::getName()
2839   const std::string getName() const override { return "AAKernelInfo"; }
2840 
2841   /// See AbstractAttribute::getIdAddr()
2842   const char *getIdAddr() const override { return &ID; }
2843 
2844   /// This function should return true if the type of the \p AA is AAKernelInfo
2845   static bool classof(const AbstractAttribute *AA) {
2846     return (AA->getIdAddr() == &ID);
2847   }
2848 
2849   static const char ID;
2850 };
2851 
/// The function kernel info abstract attribute, basically, what can we say
/// about a function with regard to the KernelInfoState.
2854 struct AAKernelInfoFunction : AAKernelInfo {
2855   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2856       : AAKernelInfo(IRP, A) {}
2857 
2858   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2859 
2860   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2861     return GuardedInstructions;
2862   }
2863 
2864   /// See AbstractAttribute::initialize(...).
2865   void initialize(Attributor &A) override {
    // This is a high-level transform that might change the constant arguments
    // of the init and deinit calls. We need to tell the Attributor about this
    // to avoid other parts using the current constant value for simplification.
2869     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2870 
2871     Function *Fn = getAnchorScope();
2872     if (!OMPInfoCache.Kernels.count(Fn))
2873       return;
2874 
2875     // Add itself to the reaching kernel and set IsKernelEntry.
2876     ReachingKernelEntries.insert(Fn);
2877     IsKernelEntry = true;
2878 
2879     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2880         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2881     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2882         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2883 
    // For kernels we perform more initialization work. First we find the init
    // and deinit calls.
2886     auto StoreCallBase = [](Use &U,
2887                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2888                             CallBase *&Storage) {
2889       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2890       assert(CB &&
2891              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2892       assert(!Storage &&
2893              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2894       Storage = CB;
2895       return false;
2896     };
2897     InitRFI.foreachUse(
2898         [&](Use &U, Function &) {
2899           StoreCallBase(U, InitRFI, KernelInitCB);
2900           return false;
2901         },
2902         Fn);
2903     DeinitRFI.foreachUse(
2904         [&](Use &U, Function &) {
2905           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2906           return false;
2907         },
2908         Fn);
2909 
2910     // Ignore kernels without initializers such as global constructors.
2911     if (!KernelInitCB || !KernelDeinitCB) {
2912       indicateOptimisticFixpoint();
2913       return;
2914     }
2915 
2916     // For kernels we might need to initialize/finalize the IsSPMD state and
2917     // we need to register a simplification callback so that the Attributor
2918     // knows the constant arguments to __kmpc_target_init and
2919     // __kmpc_target_deinit might actually change.
2920 
2921     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2922         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2923             bool &UsedAssumedInformation) -> Optional<Value *> {
2924       // IRP represents the "use generic state machine" argument of an
2925       // __kmpc_target_init call. We will answer this one with the internal
2926       // state. As long as we are not in an invalid state, we will create a
      // custom state machine so the value should be an `i1 false`. If we are
2928       // in an invalid state, we won't change the value that is in the IR.
2929       if (!isValidState())
2930         return nullptr;
2931       // If we have disabled state machine rewrites, don't make a custom one.
2932       if (DisableOpenMPOptStateMachineRewrite)
2933         return nullptr;
2934       if (AA)
2935         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2936       UsedAssumedInformation = !isAtFixpoint();
2937       auto *FalseVal =
2938           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2939       return FalseVal;
2940     };
2941 
2942     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2943         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2944             bool &UsedAssumedInformation) -> Optional<Value *> {
      // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
      // __kmpc_target_deinit call. We will answer this one with the internal
      // SPMDCompatibilityTracker state.
2949       if (!SPMDCompatibilityTracker.isValidState())
2950         return nullptr;
2951       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2952         if (AA)
2953           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2954         UsedAssumedInformation = true;
2955       } else {
2956         UsedAssumedInformation = false;
2957       }
2958       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2959                                        SPMDCompatibilityTracker.isAssumed());
2960       return Val;
2961     };
2962 
2963     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2964         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2965             bool &UsedAssumedInformation) -> Optional<Value *> {
2966       // IRP represents the "RequiresFullRuntime" argument of an
2967       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2968       // one with the internal state of the SPMDCompatibilityTracker, so if
2969       // generic then true, if SPMD then false.
2970       if (!SPMDCompatibilityTracker.isValidState())
2971         return nullptr;
2972       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2973         if (AA)
2974           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2975         UsedAssumedInformation = true;
2976       } else {
2977         UsedAssumedInformation = false;
2978       }
2979       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2980                                        !SPMDCompatibilityTracker.isAssumed());
2981       return Val;
2982     };
2983 
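    // Argument positions in the __kmpc_target_init and __kmpc_target_deinit
    // calls for which we register simplification callbacks below.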
2984     constexpr const int InitIsSPMDArgNo = 1;
2985     constexpr const int DeinitIsSPMDArgNo = 1;
2986     constexpr const int InitUseStateMachineArgNo = 2;
2987     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2988     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2989     A.registerSimplificationCallback(
2990         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2991         StateMachineSimplifyCB);
2992     A.registerSimplificationCallback(
2993         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2994         IsSPMDModeSimplifyCB);
2995     A.registerSimplificationCallback(
2996         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2997         IsSPMDModeSimplifyCB);
2998     A.registerSimplificationCallback(
2999         IRPosition::callsite_argument(*KernelInitCB,
3000                                       InitRequiresFullRuntimeArgNo),
3001         IsGenericModeSimplifyCB);
3002     A.registerSimplificationCallback(
3003         IRPosition::callsite_argument(*KernelDeinitCB,
3004                                       DeinitRequiresFullRuntimeArgNo),
3005         IsGenericModeSimplifyCB);
3006 
3007     // Check if we know we are in SPMD-mode already.
3008     ConstantInt *IsSPMDArg =
3009         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3010     if (IsSPMDArg && !IsSPMDArg->isZero())
3011       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
    // This is a generic region but SPMDization is disabled, so stop tracking.
3013     else if (DisableOpenMPOptSPMDization)
3014       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3015   }
3016 
3017   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3018   /// finished now.
3019   ChangeStatus manifest(Attributor &A) override {
3020     // If we are not looking at a kernel with __kmpc_target_init and
3021     // __kmpc_target_deinit call we cannot actually manifest the information.
3022     if (!KernelInitCB || !KernelDeinitCB)
3023       return ChangeStatus::UNCHANGED;
3024 
3025     // Known SPMD-mode kernels need no manifest changes.
3026     if (SPMDCompatibilityTracker.isKnown())
3027       return ChangeStatus::UNCHANGED;
3028 
    // If we can, we change the execution mode to SPMD-mode; otherwise we build
    // a custom state machine.
3031     if (!mayContainParallelRegion() || !changeToSPMDMode(A))
3032       buildCustomStateMachine(A);
3033 
3034     return ChangeStatus::CHANGED;
3035   }
3036 
3037   bool changeToSPMDMode(Attributor &A) {
3038     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3039 
3040     if (!SPMDCompatibilityTracker.isAssumed()) {
3041       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3042         if (!NonCompatibleI)
3043           continue;
3044 
3045         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3046         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3047           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3048             continue;
3049 
3050         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3051           ORA << "Value has potential side effects preventing SPMD-mode "
3052                  "execution";
3053           if (isa<CallBase>(NonCompatibleI)) {
3054             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3055                    "the called function to override";
3056           }
3057           return ORA << ".";
3058         };
3059         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3060                                                  Remark);
3061 
3062         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3063                           << *NonCompatibleI << "\n");
3064       }
3065 
3066       return false;
3067     }
3068 
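    // Helper to guard the region [RegionStartI, RegionEndI] so that only
    // thread 0 of the block executes it in SPMD mode; values escaping the
    // region are broadcast to the other threads through shared memory, with
    // barriers for synchronization (see the block sketch below).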
3069     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3070                                    Instruction *RegionEndI) {
3071       LoopInfo *LI = nullptr;
3072       DominatorTree *DT = nullptr;
3073       MemorySSAUpdater *MSU = nullptr;
3074       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3075 
3076       BasicBlock *ParentBB = RegionStartI->getParent();
3077       Function *Fn = ParentBB->getParent();
3078       Module &M = *Fn->getParent();
3079 
3080       // Create all the blocks and logic.
3081       // ParentBB:
3082       //    goto RegionCheckTidBB
3083       // RegionCheckTidBB:
3084       //    Tid = __kmpc_hardware_thread_id()
3085       //    if (Tid != 0)
3086       //        goto RegionBarrierBB
3087       // RegionStartBB:
3088       //    <execute instructions guarded>
3089       //    goto RegionEndBB
3090       // RegionEndBB:
3091       //    <store escaping values to shared mem>
3092       //    goto RegionBarrierBB
3093       //  RegionBarrierBB:
3094       //    __kmpc_simple_barrier_spmd()
3095       //    // second barrier is omitted if lacking escaping values.
3096       //    <load escaping values from shared mem>
3097       //    __kmpc_simple_barrier_spmd()
3098       //    goto RegionExitBB
3099       // RegionExitBB:
3100       //    <execute rest of instructions>
3101 
3102       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3103                                            DT, LI, MSU, "region.guarded.end");
3104       BasicBlock *RegionBarrierBB =
3105           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3106                      MSU, "region.barrier");
3107       BasicBlock *RegionExitBB =
3108           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3109                      DT, LI, MSU, "region.exit");
3110       BasicBlock *RegionStartBB =
3111           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3112 
3113       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3114              "Expected a different CFG");
3115 
3116       BasicBlock *RegionCheckTidBB = SplitBlock(
3117           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3118 
3119       // Register basic blocks with the Attributor.
3120       A.registerManifestAddedBasicBlock(*RegionEndBB);
3121       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3122       A.registerManifestAddedBasicBlock(*RegionExitBB);
3123       A.registerManifestAddedBasicBlock(*RegionStartBB);
3124       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3125 
3126       bool HasBroadcastValues = false;
3127       // Find escaping outputs from the guarded region to outside users and
3128       // broadcast their values to them.
3129       for (Instruction &I : *RegionStartBB) {
3130         SmallPtrSet<Instruction *, 4> OutsideUsers;
3131         for (User *Usr : I.users()) {
3132           Instruction &UsrI = *cast<Instruction>(Usr);
3133           if (UsrI.getParent() != RegionStartBB)
3134             OutsideUsers.insert(&UsrI);
3135         }
3136 
3137         if (OutsideUsers.empty())
3138           continue;
3139 
3140         HasBroadcastValues = true;
3141 
3142         // Emit a global variable in shared memory to store the broadcasted
3143         // value.
3144         auto *SharedMem = new GlobalVariable(
3145             M, I.getType(), /* IsConstant */ false,
3146             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3147             I.getName() + ".guarded.output.alloc", nullptr,
3148             GlobalValue::NotThreadLocal,
3149             static_cast<unsigned>(AddressSpace::Shared));
3150 
3151         // Emit a store instruction to update the value.
3152         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3153 
3154         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3155                                        I.getName() + ".guarded.output.load",
3156                                        RegionBarrierBB->getTerminator());
3157 
        // Replace uses of the output value outside the guarded region with
        // the value loaded from shared memory.
3159         for (Instruction *UsrI : OutsideUsers) {
3160           assert(UsrI->getParent() == RegionExitBB &&
3161                  "Expected escaping users in exit region");
3162           UsrI->replaceUsesOfWith(&I, LoadI);
3163         }
3164       }
3165 
3166       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3167 
3168       // Go to tid check BB in ParentBB.
3169       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3170       ParentBB->getTerminator()->eraseFromParent();
3171       OpenMPIRBuilder::LocationDescription Loc(
3172           InsertPointTy(ParentBB, ParentBB->end()), DL);
3173       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3174       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3175       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3176       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3177 
3178       // Add check for Tid in RegionCheckTidBB
3179       RegionCheckTidBB->getTerminator()->eraseFromParent();
3180       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3181           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3182       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3183       FunctionCallee HardwareTidFn =
3184           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3185               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3186       Value *Tid =
3187           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3188       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3189       OMPInfoCache.OMPBuilder.Builder
3190           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3191           ->setDebugLoc(DL);
3192 
      // First barrier for synchronization; it ensures the main thread has
      // updated the values before the workers read them.
3195       FunctionCallee BarrierFn =
3196           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3197               M, OMPRTL___kmpc_barrier_simple_spmd);
3198       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3199           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3200       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3201           ->setDebugLoc(DL);
3202 
3203       // Second barrier ensures workers have read broadcast values.
3204       if (HasBroadcastValues)
3205         CallInst::Create(BarrierFn, {Ident, Tid}, "",
3206                          RegionBarrierBB->getTerminator())
3207             ->setDebugLoc(DL);
3208     };
3209 
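    // Try to sink instructions that require guarding towards each other so
    // that, ideally, fewer and larger guarded regions are formed below.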
3210     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3211     SmallPtrSet<BasicBlock *, 8> Visited;
3212     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3213       BasicBlock *BB = GuardedI->getParent();
3214       if (!Visited.insert(BB).second)
3215         continue;
3216 
3217       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3218       Instruction *LastEffect = nullptr;
3219       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3220       while (++IP != IPEnd) {
3221         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3222           continue;
3223         Instruction *I = &*IP;
3224         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3225           continue;
3226         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3227           LastEffect = nullptr;
3228           continue;
3229         }
3230         if (LastEffect)
3231           Reorders.push_back({I, LastEffect});
3232         LastEffect = &*IP;
3233       }
3234       for (auto &Reorder : Reorders)
3235         Reorder.first->moveBefore(Reorder.second);
3236     }
3237 
3238     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3239 
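    // Collect contiguous runs of instructions that need guarding; each run is
    // recorded as a (start, end) pair and guarded as one region afterwards.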
3240     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3241       BasicBlock *BB = GuardedI->getParent();
3242       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3243           IRPosition::function(*GuardedI->getFunction()), nullptr,
3244           DepClassTy::NONE);
3245       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3246       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3247       // Continue if instruction is already guarded.
3248       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3249         continue;
3250 
3251       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3252       for (Instruction &I : *BB) {
        // If instruction I needs to be guarded, update the guarded region
        // bounds.
3255         if (SPMDCompatibilityTracker.contains(&I)) {
3256           CalleeAAFunction.getGuardedInstructions().insert(&I);
3257           if (GuardedRegionStart)
3258             GuardedRegionEnd = &I;
3259           else
3260             GuardedRegionStart = GuardedRegionEnd = &I;
3261 
3262           continue;
3263         }
3264 
        // Instruction I does not need guarding; store any region found so far
        // and reset the bounds.
3267         if (GuardedRegionStart) {
3268           GuardedRegions.push_back(
3269               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3270           GuardedRegionStart = nullptr;
3271           GuardedRegionEnd = nullptr;
3272         }
3273       }
3274     }
3275 
3276     for (auto &GR : GuardedRegions)
3277       CreateGuardedRegion(GR.first, GR.second);
3278 
3279     // Adjust the global exec mode flag that tells the runtime what mode this
3280     // kernel is executed in.
3281     Function *Kernel = getAnchorScope();
3282     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3283         (Kernel->getName() + "_exec_mode").str());
3284     assert(ExecMode && "Kernel without exec mode?");
3285     assert(ExecMode->getInitializer() &&
3286            ExecMode->getInitializer()->isOneValue() &&
3287            "Initially non-SPMD kernel has SPMD exec mode!");
3288 
3289     // Set the global exec mode flag to indicate SPMD-Generic mode.
3290     constexpr int SPMDGeneric = 2;
3291     if (!ExecMode->getInitializer()->isZeroValue())
3292       ExecMode->setInitializer(
3293           ConstantInt::get(ExecMode->getInitializer()->getType(), SPMDGeneric));
3294 
3295     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3296     const int InitIsSPMDArgNo = 1;
3297     const int DeinitIsSPMDArgNo = 1;
3298     const int InitUseStateMachineArgNo = 2;
3299     const int InitRequiresFullRuntimeArgNo = 3;
3300     const int DeinitRequiresFullRuntimeArgNo = 2;
3301 
3302     auto &Ctx = getAnchorValue().getContext();
3303     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3304                              *ConstantInt::getBool(Ctx, 1));
3305     A.changeUseAfterManifest(
3306         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3307         *ConstantInt::getBool(Ctx, 0));
3308     A.changeUseAfterManifest(
3309         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3310         *ConstantInt::getBool(Ctx, 1));
3311     A.changeUseAfterManifest(
3312         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3313         *ConstantInt::getBool(Ctx, 0));
3314     A.changeUseAfterManifest(
3315         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3316         *ConstantInt::getBool(Ctx, 0));
3317 
3318     ++NumOpenMPTargetRegionKernelsSPMD;
3319 
3320     auto Remark = [&](OptimizationRemark OR) {
3321       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3322     };
3323     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3324     return true;
3325   };
3326 
3327   ChangeStatus buildCustomStateMachine(Attributor &A) {
3328     // If we have disabled state machine rewrites, don't make a custom one
3329     if (DisableOpenMPOptStateMachineRewrite)
3330       return indicatePessimisticFixpoint();
3331 
3332     assert(ReachedKnownParallelRegions.isValidState() &&
3333            "Custom state machine with invalid parallel region states?");
3334 
3335     const int InitIsSPMDArgNo = 1;
3336     const int InitUseStateMachineArgNo = 2;
3337 
    // Check if the current configuration is non-SPMD and uses the generic
    // state machine. If we already have SPMD mode or a custom state machine we
    // do not need to go any further. If either argument is anything but a
    // constant, something is weird and we give up.
3342     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3343         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3344     ConstantInt *IsSPMD =
3345         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3346 
3347     // If we are stuck with generic mode, try to create a custom device (=GPU)
3348     // state machine which is specialized for the parallel regions that are
3349     // reachable by the kernel.
3350     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3351         !IsSPMD->isZero())
3352       return ChangeStatus::UNCHANGED;
3353 
3354     // If not SPMD mode, indicate we use a custom state machine now.
3355     auto &Ctx = getAnchorValue().getContext();
3356     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3357     A.changeUseAfterManifest(
3358         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3359 
3360     // If we don't actually need a state machine we are done here. This can
3361     // happen if there simply are no parallel regions. In the resulting kernel
3362     // all worker threads will simply exit right away, leaving the main thread
3363     // to do the work alone.
3364     if (!mayContainParallelRegion()) {
3365       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3366 
3367       auto Remark = [&](OptimizationRemark OR) {
3368         return OR << "Removing unused state machine from generic-mode kernel.";
3369       };
3370       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3371 
3372       return ChangeStatus::CHANGED;
3373     }
3374 
3375     // Keep track in the statistics of our new shiny custom state machine.
3376     if (ReachedUnknownParallelRegions.empty()) {
3377       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3378 
3379       auto Remark = [&](OptimizationRemark OR) {
3380         return OR << "Rewriting generic-mode kernel with a customized state "
3381                      "machine.";
3382       };
3383       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3384     } else {
3385       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3386 
3387       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3388         return OR << "Generic-mode kernel is executed with a customized state "
3389                      "machine that requires a fallback.";
3390       };
3391       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3392 
3393       // Tell the user why we ended up with a fallback.
3394       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3395         if (!UnknownParallelRegionCB)
3396           continue;
3397         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3398           return ORA << "Call may contain unknown parallel regions. Use "
3399                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3400                         "override.";
3401         };
3402         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3403                                                  "OMP133", Remark);
3404       }
3405     }
3406 
3407     // Create all the blocks:
3408     //
3409     //                       InitCB = __kmpc_target_init(...)
3410     //                       bool IsWorker = InitCB >= 0;
3411     //                       if (IsWorker) {
3412     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3413     //                         void *WorkFn;
3414     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3415     //                         if (!WorkFn) return;
3416     // SMIsActiveCheckBB:       if (Active) {
3417     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3418     //                              ParFn0(...);
3419     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3420     //                              ParFn1(...);
3421     //                            ...
3422     // SMIfCascadeCurrentBB:      else
3423     //                              ((WorkFnTy*)WorkFn)(...);
3424     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3425     //                          }
3426     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3427     //                          goto SMBeginBB;
3428     //                       }
3429     // UserCodeEntryBB:      // user code
3430     //                       __kmpc_target_deinit(...)
3431     //
3432     Function *Kernel = getAssociatedFunction();
3433     assert(Kernel && "Expected an associated function!");
3434 
3435     BasicBlock *InitBB = KernelInitCB->getParent();
3436     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3437         KernelInitCB->getNextNode(), "thread.user_code.check");
3438     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3439         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3440     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3441         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3442     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3443         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3444     BasicBlock *StateMachineIfCascadeCurrentBB =
3445         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3446                            Kernel, UserCodeEntryBB);
3447     BasicBlock *StateMachineEndParallelBB =
3448         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3449                            Kernel, UserCodeEntryBB);
3450     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3451         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3452     A.registerManifestAddedBasicBlock(*InitBB);
3453     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3454     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3455     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3456     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3457     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3458     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3459     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3460 
3461     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3462     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3463 
3464     InitBB->getTerminator()->eraseFromParent();
3465     Instruction *IsWorker =
3466         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3467                          ConstantInt::get(KernelInitCB->getType(), -1),
3468                          "thread.is_worker", InitBB);
3469     IsWorker->setDebugLoc(DLoc);
3470     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3471 
3472     Module &M = *Kernel->getParent();
3473 
3474     // Create local storage for the work function pointer.
3475     const DataLayout &DL = M.getDataLayout();
3476     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3477     Instruction *WorkFnAI =
3478         new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3479                        "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3480     WorkFnAI->setDebugLoc(DLoc);
3481 
3482     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3483     OMPInfoCache.OMPBuilder.updateToLocation(
3484         OpenMPIRBuilder::LocationDescription(
3485             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3486                                      StateMachineBeginBB->end()),
3487             DLoc));
3488 
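    // Reuse the ident argument of the init call and use its return value as
    // the thread id for the runtime calls emitted below.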
3489     Value *Ident = KernelInitCB->getArgOperand(0);
3490     Value *GTid = KernelInitCB;
3491 
3492     FunctionCallee BarrierFn =
3493         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3494             M, OMPRTL___kmpc_barrier_simple_spmd);
3495     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3496         ->setDebugLoc(DLoc);
3497 
3498     if (WorkFnAI->getType()->getPointerAddressSpace() !=
3499         (unsigned int)AddressSpace::Generic) {
3500       WorkFnAI = new AddrSpaceCastInst(
3501           WorkFnAI,
3502           PointerType::getWithSamePointeeType(
3503               cast<PointerType>(WorkFnAI->getType()),
3504               (unsigned int)AddressSpace::Generic),
3505           WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3506       WorkFnAI->setDebugLoc(DLoc);
3507     }
3508 
3509     FunctionCallee KernelParallelFn =
3510         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3511             M, OMPRTL___kmpc_kernel_parallel);
3512     Instruction *IsActiveWorker = CallInst::Create(
3513         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3514     IsActiveWorker->setDebugLoc(DLoc);
3515     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3516                                        StateMachineBeginBB);
3517     WorkFn->setDebugLoc(DLoc);
3518 
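    // All parallel region wrapper functions recorded from __kmpc_parallel_51
    // call sites are modeled with the signature void(int16_t, int32_t).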
3519     FunctionType *ParallelRegionFnTy = FunctionType::get(
3520         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3521         false);
3522     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3523         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3524         StateMachineBeginBB);
3525 
3526     Instruction *IsDone =
3527         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3528                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3529                          StateMachineBeginBB);
3530     IsDone->setDebugLoc(DLoc);
3531     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3532                        IsDone, StateMachineBeginBB)
3533         ->setDebugLoc(DLoc);
3534 
3535     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3536                        StateMachineDoneBarrierBB, IsActiveWorker,
3537                        StateMachineIsActiveCheckBB)
3538         ->setDebugLoc(DLoc);
3539 
3540     Value *ZeroArg =
3541         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3542 
3543     // Now that we have most of the CFG skeleton it is time for the if-cascade
3544     // that checks the function pointer we got from the runtime against the
3545     // parallel regions we expect, if there are any.
3546     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3547       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3548       BasicBlock *PRExecuteBB = BasicBlock::Create(
3549           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3550           StateMachineEndParallelBB);
3551       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3552           ->setDebugLoc(DLoc);
3553       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3554           ->setDebugLoc(DLoc);
3555 
3556       BasicBlock *PRNextBB =
3557           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3558                              Kernel, StateMachineEndParallelBB);
3559 
3560       // Check if we need to compare the pointer at all or if we can just
3561       // call the parallel region function.
3562       Value *IsPR;
3563       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3564         Instruction *CmpI = ICmpInst::Create(
3565             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3566             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3567         CmpI->setDebugLoc(DLoc);
3568         IsPR = CmpI;
3569       } else {
3570         IsPR = ConstantInt::getTrue(Ctx);
3571       }
3572 
3573       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3574                          StateMachineIfCascadeCurrentBB)
3575           ->setDebugLoc(DLoc);
3576       StateMachineIfCascadeCurrentBB = PRNextBB;
3577     }
3578 
    // At the end of the if-cascade we place the indirect function pointer call
    // in case we might need it, that is, if there can be parallel regions we
    // have not handled in the if-cascade above.
3582     if (!ReachedUnknownParallelRegions.empty()) {
3583       StateMachineIfCascadeCurrentBB->setName(
3584           "worker_state_machine.parallel_region.fallback.execute");
3585       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3586                        StateMachineIfCascadeCurrentBB)
3587           ->setDebugLoc(DLoc);
3588     }
3589     BranchInst::Create(StateMachineEndParallelBB,
3590                        StateMachineIfCascadeCurrentBB)
3591         ->setDebugLoc(DLoc);
3592 
3593     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3594                          M, OMPRTL___kmpc_kernel_end_parallel),
3595                      {}, "", StateMachineEndParallelBB)
3596         ->setDebugLoc(DLoc);
3597     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3598         ->setDebugLoc(DLoc);
3599 
3600     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3601         ->setDebugLoc(DLoc);
3602     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3603         ->setDebugLoc(DLoc);
3604 
3605     return ChangeStatus::CHANGED;
3606   }
3607 
  /// Fixpoint iteration update function. Will be called every time a dependence
  /// changes its state (and in the beginning).
3610   ChangeStatus updateImpl(Attributor &A) override {
3611     KernelInfoState StateBefore = getState();
3612 
3613     // Callback to check a read/write instruction.
3614     auto CheckRWInst = [&](Instruction &I) {
3615       // We handle calls later.
3616       if (isa<CallBase>(I))
3617         return true;
3618       // We only care about write effects.
3619       if (!I.mayWriteToMemory())
3620         return true;
3621       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3622         SmallVector<const Value *> Objects;
3623         getUnderlyingObjects(SI->getPointerOperand(), Objects);
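        // If the store only accesses thread-private stack memory (allocas),
        // it does not need guarding.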
3624         if (llvm::all_of(Objects,
3625                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3626           return true;
3627         // Check for AAHeapToStack moved objects which must not be guarded.
3628         auto &HS = A.getAAFor<AAHeapToStack>(
3629             *this, IRPosition::function(*I.getFunction()),
3630             DepClassTy::REQUIRED);
3631         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3632               auto *CB = dyn_cast<CallBase>(Obj);
3633               if (!CB)
3634                 return false;
3635               return HS.isAssumedHeapToStack(*CB);
3636             })) {
3637           return true;
3638         }
3639       }
3640 
3641       // Insert instruction that needs guarding.
3642       SPMDCompatibilityTracker.insert(&I);
3643       return true;
3644     };
3645 
3646     bool UsedAssumedInformationInCheckRWInst = false;
3647     if (!SPMDCompatibilityTracker.isAtFixpoint())
3648       if (!A.checkForAllReadWriteInstructions(
3649               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3650         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3651 
3652     if (!IsKernelEntry) {
3653       updateReachingKernelEntries(A);
3654       updateParallelLevels(A);
3655 
3656       if (!ParallelLevels.isValidState())
3657         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3658     }
3659 
3660     // Callback to check a call instruction.
3661     bool AllSPMDStatesWereFixed = true;
3662     auto CheckCallInst = [&](Instruction &I) {
3663       auto &CB = cast<CallBase>(I);
3664       auto &CBAA = A.getAAFor<AAKernelInfo>(
3665           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3666       getState() ^= CBAA.getState();
3667       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3668       return true;
3669     };
3670 
3671     bool UsedAssumedInformationInCheckCallInst = false;
3672     if (!A.checkForAllCallLikeInstructions(
3673             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3674       return indicatePessimisticFixpoint();
3675 
3676     // If we haven't used any assumed information for the SPMD state we can fix
3677     // it.
3678     if (!UsedAssumedInformationInCheckRWInst &&
3679         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3680       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3681 
3682     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3683                                      : ChangeStatus::CHANGED;
3684   }
3685 
3686 private:
3687   /// Update info regarding reaching kernels.
3688   void updateReachingKernelEntries(Attributor &A) {
3689     auto PredCallSite = [&](AbstractCallSite ACS) {
3690       Function *Caller = ACS.getInstruction()->getFunction();
3691 
3692       assert(Caller && "Caller is nullptr");
3693 
3694       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3695           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3696       if (CAA.ReachingKernelEntries.isValidState()) {
3697         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3698         return true;
3699       }
3700 
      // We lost track of the callers of the associated function; any kernel
      // could reach it now.
3703       ReachingKernelEntries.indicatePessimisticFixpoint();
3704 
3705       return true;
3706     };
3707 
3708     bool AllCallSitesKnown;
3709     if (!A.checkForAllCallSites(PredCallSite, *this,
3710                                 true /* RequireAllCallSites */,
3711                                 AllCallSitesKnown))
3712       ReachingKernelEntries.indicatePessimisticFixpoint();
3713   }
3714 
3715   /// Update info regarding parallel levels.
3716   void updateParallelLevels(Attributor &A) {
3717     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3718     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3719         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3720 
3721     auto PredCallSite = [&](AbstractCallSite ACS) {
3722       Function *Caller = ACS.getInstruction()->getFunction();
3723 
3724       assert(Caller && "Caller is nullptr");
3725 
3726       auto &CAA =
3727           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3728       if (CAA.ParallelLevels.isValidState()) {
        // Any function that is called by `__kmpc_parallel_51` will not be
        // folded because the parallel level in that function is updated. To
        // model this precisely the analysis would have to depend on the
        // runtime implementation, and any future change to that implementation
        // could silently invalidate it. As a consequence, we are just
        // conservative here.
3734         if (Caller == Parallel51RFI.Declaration) {
3735           ParallelLevels.indicatePessimisticFixpoint();
3736           return true;
3737         }
3738 
3739         ParallelLevels ^= CAA.ParallelLevels;
3740 
3741         return true;
3742       }
3743 
      // We lost track of the callers of the associated function; any kernel
      // could reach it now.
3746       ParallelLevels.indicatePessimisticFixpoint();
3747 
3748       return true;
3749     };
3750 
3751     bool AllCallSitesKnown = true;
3752     if (!A.checkForAllCallSites(PredCallSite, *this,
3753                                 true /* RequireAllCallSites */,
3754                                 AllCallSitesKnown))
3755       ParallelLevels.indicatePessimisticFixpoint();
3756   }
3757 };
3758 
/// The call site kernel info abstract attribute, basically, what can we say
/// about a call site with regard to the KernelInfoState. For now this simply
/// forwards the information from the callee.
3762 struct AAKernelInfoCallSite : AAKernelInfo {
3763   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3764       : AAKernelInfo(IRP, A) {}
3765 
3766   /// See AbstractAttribute::initialize(...).
3767   void initialize(Attributor &A) override {
3768     AAKernelInfo::initialize(A);
3769 
3770     CallBase &CB = cast<CallBase>(getAssociatedValue());
3771     Function *Callee = getAssociatedFunction();
3772 
3773     // Helper to lookup an assumption string.
3774     auto HasAssumption = [](CallBase &CB, StringRef AssumptionStr) {
3775       return hasAssumption(CB, AssumptionStr);
3776     };
3777 
3778     // Check for SPMD-mode assumptions.
3779     if (HasAssumption(CB, "ompx_spmd_amenable")) {
3780       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3781       indicateOptimisticFixpoint();
3782     }
3783 
    // First weed out calls we do not care about, that is, readonly/readnone
    // calls, intrinsics, and "no_openmp" calls. None of these can reach a
    // parallel region or anything else we are looking for.
3787     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3788       indicateOptimisticFixpoint();
3789       return;
3790     }
3791 
3792     // Next we check if we know the callee. If it is a known OpenMP function
3793     // we will handle them explicitly in the switch below. If it is not, we
3794     // will use an AAKernelInfo object on the callee to gather information and
3795     // merge that into the current state. The latter happens in the updateImpl.
3796     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3797     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3798     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
      // Unknown callees or declarations that are not IPO amendable cannot be
      // analyzed; we give up.
3800       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3801 
3802         // Unknown callees might contain parallel regions, except if they have
3803         // an appropriate assumption attached.
3804         if (!(HasAssumption(CB, "omp_no_openmp") ||
3805               HasAssumption(CB, "omp_no_parallelism")))
3806           ReachedUnknownParallelRegions.insert(&CB);
3807 
3808         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3809         // idea we can run something unknown in SPMD-mode.
3810         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3811           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3812           SPMDCompatibilityTracker.insert(&CB);
3813         }
3814 
        // We have updated the state for this unknown call properly; there
        // won't be any further change, so we indicate a fixpoint.
3817         indicateOptimisticFixpoint();
3818       }
3819       // If the callee is known and can be used in IPO, we will update the state
3820       // based on the callee state in updateImpl.
3821       return;
3822     }
3823 
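    // Argument position of the parallel region wrapper function in a
    // __kmpc_parallel_51 call.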
3824     const unsigned int WrapperFunctionArgNo = 6;
3825     RuntimeFunction RF = It->getSecond();
3826     switch (RF) {
3827     // All the functions we know are compatible with SPMD mode.
3828     case OMPRTL___kmpc_is_spmd_exec_mode:
3829     case OMPRTL___kmpc_for_static_fini:
3830     case OMPRTL___kmpc_global_thread_num:
3831     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3832     case OMPRTL___kmpc_get_hardware_num_blocks:
3833     case OMPRTL___kmpc_single:
3834     case OMPRTL___kmpc_end_single:
3835     case OMPRTL___kmpc_master:
3836     case OMPRTL___kmpc_end_master:
3837     case OMPRTL___kmpc_barrier:
3838       break;
3839     case OMPRTL___kmpc_for_static_init_4:
3840     case OMPRTL___kmpc_for_static_init_4u:
3841     case OMPRTL___kmpc_for_static_init_8:
3842     case OMPRTL___kmpc_for_static_init_8u: {
3843       // Check the schedule and allow static schedule in SPMD mode.
3844       unsigned ScheduleArgOpNo = 2;
3845       auto *ScheduleTypeCI =
3846           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3847       unsigned ScheduleTypeVal =
3848           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3849       switch (OMPScheduleType(ScheduleTypeVal)) {
3850       case OMPScheduleType::Static:
3851       case OMPScheduleType::StaticChunked:
3852       case OMPScheduleType::Distribute:
3853       case OMPScheduleType::DistributeChunked:
3854         break;
3855       default:
3856         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3857         SPMDCompatibilityTracker.insert(&CB);
3858         break;
3859       };
3860     } break;
3861     case OMPRTL___kmpc_target_init:
3862       KernelInitCB = &CB;
3863       break;
3864     case OMPRTL___kmpc_target_deinit:
3865       KernelDeinitCB = &CB;
3866       break;
3867     case OMPRTL___kmpc_parallel_51:
3868       if (auto *ParallelRegion = dyn_cast<Function>(
3869               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3870         ReachedKnownParallelRegions.insert(ParallelRegion);
3871         break;
3872       }
      // The condition above should usually get the parallel region function
      // pointer and record it. In the off chance it doesn't, we assume the
      // worst.
3876       ReachedUnknownParallelRegions.insert(&CB);
3877       break;
3878     case OMPRTL___kmpc_omp_task:
3879       // We do not look into tasks right now, just give up.
3880       SPMDCompatibilityTracker.insert(&CB);
3881       ReachedUnknownParallelRegions.insert(&CB);
3882       indicatePessimisticFixpoint();
3883       return;
3884     case OMPRTL___kmpc_alloc_shared:
3885     case OMPRTL___kmpc_free_shared:
3886       // Return without setting a fixpoint, to be resolved in updateImpl.
3887       return;
3888     default:
3889       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3890       // generally.
3891       SPMDCompatibilityTracker.insert(&CB);
3892       indicatePessimisticFixpoint();
3893       return;
3894     }
    // All other OpenMP runtime calls will not reach parallel regions, so they
    // can safely be ignored for now. Since it is a known OpenMP runtime call,
    // we have now modeled all effects and there is no need for any update.
3898     indicateOptimisticFixpoint();
3899   }
3900 
3901   ChangeStatus updateImpl(Attributor &A) override {
3902     // TODO: Once we have call site specific value information we can provide
3903     //       call site specific liveness information and then it makes
3904     //       sense to specialize attributes for call sites arguments instead of
3905     //       redirecting requests to the callee argument.
3906     Function *F = getAssociatedFunction();
3907 
3908     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3909     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3910 
3911     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3912     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3913       const IRPosition &FnPos = IRPosition::function(*F);
3914       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3915       if (getState() == FnAA.getState())
3916         return ChangeStatus::UNCHANGED;
3917       getState() = FnAA.getState();
3918       return ChangeStatus::CHANGED;
3919     }
3920 
    // F is a runtime function that allocates or frees memory; check
    // AAHeapToStack and AAHeapToShared.
3923     KernelInfoState StateBefore = getState();
3924     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3925             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3926            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3927 
3928     CallBase &CB = cast<CallBase>(getAssociatedValue());
3929 
3930     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3931         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3932     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3933         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3934 
3935     RuntimeFunction RF = It->getSecond();
3936 
3937     switch (RF) {
3938     // If neither HeapToStack nor HeapToShared assume the call is removed,
3939     // assume SPMD incompatibility.
3940     case OMPRTL___kmpc_alloc_shared:
3941       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3942           !HeapToSharedAA.isAssumedHeapToShared(CB))
3943         SPMDCompatibilityTracker.insert(&CB);
3944       break;
3945     case OMPRTL___kmpc_free_shared:
3946       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3947           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3948         SPMDCompatibilityTracker.insert(&CB);
3949       break;
3950     default:
3951       SPMDCompatibilityTracker.insert(&CB);
3952     }
3953 
3954     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3955                                      : ChangeStatus::CHANGED;
3956   }
3957 };
3958 
3959 struct AAFoldRuntimeCall
3960     : public StateWrapper<BooleanState, AbstractAttribute> {
3961   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3962 
3963   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3964 
3965   /// Statistics are tracked as part of manifest for now.
3966   void trackStatistics() const override {}
3967 
  /// Create an abstract attribute view for the position \p IRP.
3969   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3970                                               Attributor &A);
3971 
3972   /// See AbstractAttribute::getName()
3973   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3974 
3975   /// See AbstractAttribute::getIdAddr()
3976   const char *getIdAddr() const override { return &ID; }
3977 
3978   /// This function should return true if the type of the \p AA is
3979   /// AAFoldRuntimeCall
3980   static bool classof(const AbstractAttribute *AA) {
3981     return (AA->getIdAddr() == &ID);
3982   }
3983 
3984   static const char ID;
3985 };
3986 
3987 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3988   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3989       : AAFoldRuntimeCall(IRP, A) {}
3990 
3991   /// See AbstractAttribute::getAsStr()
3992   const std::string getAsStr() const override {
3993     if (!isValidState())
3994       return "<invalid>";
3995 
3996     std::string Str("simplified value: ");
3997 
3998     if (!SimplifiedValue.hasValue())
3999       return Str + std::string("none");
4000 
4001     if (!SimplifiedValue.getValue())
4002       return Str + std::string("nullptr");
4003 
4004     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4005       return Str + std::to_string(CI->getSExtValue());
4006 
4007     return Str + std::string("unknown");
4008   }
4009 
4010   void initialize(Attributor &A) override {
4011     if (DisableOpenMPOptFolding)
4012       indicatePessimisticFixpoint();
4013 
4014     Function *Callee = getAssociatedFunction();
4015 
4016     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4017     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4018     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4019            "Expected a known OpenMP runtime function");
4020 
4021     RFKind = It->getSecond();
4022 
4023     CallBase &CB = cast<CallBase>(getAssociatedValue());
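    // Register a simplification callback so the Attributor can query the
    // assumed folded value of this call site during the fixpoint iteration.
    // While we are not at a fixpoint, the callback reports that assumed
    // information was used and records an optional dependence so the querying
    // AA is updated again if our state changes.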
4024     A.registerSimplificationCallback(
4025         IRPosition::callsite_returned(CB),
4026         [&](const IRPosition &IRP, const AbstractAttribute *AA,
4027             bool &UsedAssumedInformation) -> Optional<Value *> {
4028           assert((isValidState() || (SimplifiedValue.hasValue() &&
4029                                      SimplifiedValue.getValue() == nullptr)) &&
4030                  "Unexpected invalid state!");
4031 
4032           if (!isAtFixpoint()) {
4033             UsedAssumedInformation = true;
4034             if (AA)
4035               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4036           }
4037           return SimplifiedValue;
4038         });
4039   }
4040 
4041   ChangeStatus updateImpl(Attributor &A) override {
4042     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4043     switch (RFKind) {
4044     case OMPRTL___kmpc_is_spmd_exec_mode:
4045       Changed |= foldIsSPMDExecMode(A);
4046       break;
4047     case OMPRTL___kmpc_is_generic_main_thread_id:
4048       Changed |= foldIsGenericMainThread(A);
4049       break;
4050     case OMPRTL___kmpc_parallel_level:
4051       Changed |= foldParallelLevel(A);
4052       break;
4053     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
4055       break;
4056     case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
4058       break;
4059     default:
4060       llvm_unreachable("Unhandled OpenMP runtime function!");
4061     }
4062 
4063     return Changed;
4064   }
4065 
4066   ChangeStatus manifest(Attributor &A) override {
4067     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4068 
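    // If folding succeeded with a non-null value, replace all uses of the
    // runtime call with the simplified constant and erase the call.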
4069     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4070       Instruction &I = *getCtxI();
4071       A.changeValueAfterManifest(I, **SimplifiedValue);
4072       A.deleteAfterManifest(I);
4073 
4074       CallBase *CB = dyn_cast<CallBase>(&I);
4075       auto Remark = [&](OptimizationRemark OR) {
4076         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4077           return OR << "Replacing OpenMP runtime call "
4078                     << CB->getCalledFunction()->getName() << " with "
4079                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4080         else
4081           return OR << "Replacing OpenMP runtime call "
4082                     << CB->getCalledFunction()->getName() << ".";
4083       };
4084 
4085       if (CB && EnableVerboseRemarks)
4086         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4087 
4088       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4089                         << **SimplifiedValue << "\n");
4090 
4091       Changed = ChangeStatus::CHANGED;
4092     }
4093 
4094     return Changed;
4095   }
4096 
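  // A pessimistic fixpoint encodes "cannot be folded" as a nullptr simplified
  // value; an unset Optional instead means the result is not known yet.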
4097   ChangeStatus indicatePessimisticFixpoint() override {
4098     SimplifiedValue = nullptr;
4099     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4100   }
4101 
4102 private:
4103   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
4104   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4105     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4106 
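    // Count the (assumed/known) SPMD and non-SPMD kernels that can reach this
    // call. Folding is only possible if all reaching kernels agree on the
    // execution mode.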
4107     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4108     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4109     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4110         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4111 
4112     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4113       return indicatePessimisticFixpoint();
4114 
4115     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4116       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4117                                           DepClassTy::REQUIRED);
4118 
4119       if (!AA.isValidState()) {
4120         SimplifiedValue = nullptr;
4121         return indicatePessimisticFixpoint();
4122       }
4123 
4124       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4125         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4126           ++KnownSPMDCount;
4127         else
4128           ++AssumedSPMDCount;
4129       } else {
4130         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4131           ++KnownNonSPMDCount;
4132         else
4133           ++AssumedNonSPMDCount;
4134       }
4135     }
4136 
4137     if ((AssumedSPMDCount + KnownSPMDCount) &&
4138         (AssumedNonSPMDCount + KnownNonSPMDCount))
4139       return indicatePessimisticFixpoint();
4140 
4141     auto &Ctx = getAnchorValue().getContext();
4142     if (KnownSPMDCount || AssumedSPMDCount) {
4143       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4144              "Expected only SPMD kernels!");
4145       // All reaching kernels are in SPMD mode. Update all function calls to
4146       // __kmpc_is_spmd_exec_mode to 1.
4147       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4148     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4149       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4150              "Expected only non-SPMD kernels!");
4151       // All reaching kernels are in non-SPMD mode. Update all function
4152       // calls to __kmpc_is_spmd_exec_mode to 0.
4153       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4154     } else {
      // The set of reaching kernels is empty, so we cannot tell whether the
      // associated call site can be folded. At this point, SimplifiedValue
      // must be none.
4158       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4159     }
4160 
4161     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4162                                                     : ChangeStatus::CHANGED;
4163   }
4164 
4165   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
4166   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4167     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4168 
4169     CallBase &CB = cast<CallBase>(getAssociatedValue());
4170     Function *F = CB.getFunction();
4171     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4172         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4173 
4174     if (!ExecutionDomainAA.isValidState())
4175       return indicatePessimisticFixpoint();
4176 
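    // Fold to true only when the execution domain shows this call site is
    // reached exclusively by the initial (main) thread; otherwise give up.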
4177     auto &Ctx = getAnchorValue().getContext();
4178     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4179       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4180     else
4181       return indicatePessimisticFixpoint();
4182 
4183     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4184                                                     : ChangeStatus::CHANGED;
4185   }
4186 
4187   /// Fold __kmpc_parallel_level into a constant if possible.
4188   ChangeStatus foldParallelLevel(Attributor &A) {
4189     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4190 
4191     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4192         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4193 
4194     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4195       return indicatePessimisticFixpoint();
4196 
4197     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4198       return indicatePessimisticFixpoint();
4199 
4200     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
      assert(!SimplifiedValue.hasValue() &&
             "SimplifiedValue should still be none at this point");
4203       return ChangeStatus::UNCHANGED;
4204     }
4205 
4206     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4207     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4208     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4209       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4210                                           DepClassTy::REQUIRED);
4211       if (!AA.SPMDCompatibilityTracker.isValidState())
4212         return indicatePessimisticFixpoint();
4213 
4214       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4215         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4216           ++KnownSPMDCount;
4217         else
4218           ++AssumedSPMDCount;
4219       } else {
4220         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4221           ++KnownNonSPMDCount;
4222         else
4223           ++AssumedNonSPMDCount;
4224       }
4225     }
4226 
4227     if ((AssumedSPMDCount + KnownSPMDCount) &&
4228         (AssumedNonSPMDCount + KnownNonSPMDCount))
4229       return indicatePessimisticFixpoint();
4230 
4231     auto &Ctx = getAnchorValue().getContext();
4232     // If the caller can only be reached by SPMD kernel entries, the parallel
4233     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4234     // kernel entries, it is 0.
4235     if (AssumedSPMDCount || KnownSPMDCount) {
4236       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4237              "Expected only SPMD kernels!");
4238       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4239     } else {
4240       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4241              "Expected only non-SPMD kernels!");
4242       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4243     }
4244     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4245                                                     : ChangeStatus::CHANGED;
4246   }
4247 
4248   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
    // Specialize only if all reaching kernels agree on the attribute's
    // constant value.
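    // For example, if every reaching kernel carries the (illustrative)
    // attribute "omp_target_thread_limit"="128", a call to
    // __kmpc_get_hardware_num_threads_in_block folds to the constant 128.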
4250     int32_t CurrentAttrValue = -1;
4251     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4252 
4253     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4254         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4255 
4256     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4257       return indicatePessimisticFixpoint();
4258 
    // Iterate over the kernels that reach this function.
4260     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4261       int32_t NextAttrVal = -1;
4262       if (K->hasFnAttribute(Attr))
4263         NextAttrVal =
4264             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4265 
4266       if (NextAttrVal == -1 ||
4267           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4268         return indicatePessimisticFixpoint();
4269       CurrentAttrValue = NextAttrVal;
4270     }
4271 
4272     if (CurrentAttrValue != -1) {
4273       auto &Ctx = getAnchorValue().getContext();
4274       SimplifiedValue =
4275           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4276     }
4277     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4278                                                     : ChangeStatus::CHANGED;
4279   }
4280 
4281   /// An optional value the associated value is assumed to fold to. That is, we
4282   /// assume the associated value (which is a call) can be replaced by this
4283   /// simplified value.
4284   Optional<Value *> SimplifiedValue;
4285 
4286   /// The runtime function kind of the callee of the associated call site.
4287   RuntimeFunction RFKind;
4288 };
4289 
4290 } // namespace
4291 
/// Register folding call sites for the runtime function \p RF.
4293 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4294   auto &RFI = OMPInfoCache.RFIs[RF];
4295   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4296     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4297     if (!CI)
4298       return false;
4299     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4300         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4301         DepClassTy::NONE, /* ForceUpdate */ false,
4302         /* UpdateAfterInit */ false);
4303     return false;
4304   });
4305 }
4306 
4307 void OpenMPOpt::registerAAs(bool IsModulePass) {
  if (SCC.empty())
    return;
4311   if (IsModulePass) {
4312     // Ensure we create the AAKernelInfo AAs first and without triggering an
4313     // update. This will make sure we register all value simplification
4314     // callbacks before any other AA has the chance to create an AAValueSimplify
4315     // or similar.
4316     for (Function *Kernel : OMPInfoCache.Kernels)
4317       A.getOrCreateAAFor<AAKernelInfo>(
4318           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4319           DepClassTy::NONE, /* ForceUpdate */ false,
4320           /* UpdateAfterInit */ false);
4321 
4322     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4323     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4324     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4325     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4326     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4327   }
4328 
4329   // Create CallSite AA for all Getters.
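  // The last enumerator of InternalControlVar is a sentinel (ICV___last),
  // hence the "- 1" in the loop bound below.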
4330   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4331     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4332 
4333     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4334 
4335     auto CreateAA = [&](Use &U, Function &Caller) {
4336       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4337       if (!CI)
4338         return false;
4339 
4340       auto &CB = cast<CallBase>(*CI);
4341 
4342       IRPosition CBPos = IRPosition::callsite_function(CB);
4343       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4344       return false;
4345     };
4346 
4347     GetterRFI.foreachUse(SCC, CreateAA);
4348   }
4349   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4350   auto CreateAA = [&](Use &U, Function &F) {
4351     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4352     return false;
4353   };
4354   if (!DisableOpenMPOptDeglobalization)
4355     GlobalizationRFI.foreachUse(SCC, CreateAA);
4356 
4357   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4358   // every function if there is a device kernel.
4359   if (!isOpenMPDevice(M))
4360     return;
4361 
4362   for (auto *F : SCC) {
4363     if (F->isDeclaration())
4364       continue;
4365 
4366     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4367     if (!DisableOpenMPOptDeglobalization)
4368       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4369 
4370     for (auto &I : instructions(*F)) {
4371       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4372         bool UsedAssumedInformation = false;
4373         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4374                                UsedAssumedInformation);
4375       }
4376     }
4377   }
4378 }
4379 
4380 const char AAICVTracker::ID = 0;
4381 const char AAKernelInfo::ID = 0;
4382 const char AAExecutionDomain::ID = 0;
4383 const char AAHeapToShared::ID = 0;
4384 const char AAFoldRuntimeCall::ID = 0;
4385 
4386 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4387                                               Attributor &A) {
4388   AAICVTracker *AA = nullptr;
4389   switch (IRP.getPositionKind()) {
4390   case IRPosition::IRP_INVALID:
4391   case IRPosition::IRP_FLOAT:
4392   case IRPosition::IRP_ARGUMENT:
4393   case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable("ICVTracker can only be created for function, call "
                     "site, and returned positions!");
4395   case IRPosition::IRP_RETURNED:
4396     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4397     break;
4398   case IRPosition::IRP_CALL_SITE_RETURNED:
4399     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4400     break;
4401   case IRPosition::IRP_CALL_SITE:
4402     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4403     break;
4404   case IRPosition::IRP_FUNCTION:
4405     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4406     break;
4407   }
4408 
4409   return *AA;
4410 }
4411 
4412 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4413                                                         Attributor &A) {
4414   AAExecutionDomainFunction *AA = nullptr;
4415   switch (IRP.getPositionKind()) {
4416   case IRPosition::IRP_INVALID:
4417   case IRPosition::IRP_FLOAT:
4418   case IRPosition::IRP_ARGUMENT:
4419   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4420   case IRPosition::IRP_RETURNED:
4421   case IRPosition::IRP_CALL_SITE_RETURNED:
4422   case IRPosition::IRP_CALL_SITE:
4423     llvm_unreachable(
4424         "AAExecutionDomain can only be created for function position!");
4425   case IRPosition::IRP_FUNCTION:
4426     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4427     break;
4428   }
4429 
4430   return *AA;
4431 }
4432 
4433 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4434                                                   Attributor &A) {
4435   AAHeapToSharedFunction *AA = nullptr;
4436   switch (IRP.getPositionKind()) {
4437   case IRPosition::IRP_INVALID:
4438   case IRPosition::IRP_FLOAT:
4439   case IRPosition::IRP_ARGUMENT:
4440   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4441   case IRPosition::IRP_RETURNED:
4442   case IRPosition::IRP_CALL_SITE_RETURNED:
4443   case IRPosition::IRP_CALL_SITE:
4444     llvm_unreachable(
4445         "AAHeapToShared can only be created for function position!");
4446   case IRPosition::IRP_FUNCTION:
4447     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4448     break;
4449   }
4450 
4451   return *AA;
4452 }
4453 
4454 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4455                                               Attributor &A) {
4456   AAKernelInfo *AA = nullptr;
4457   switch (IRP.getPositionKind()) {
4458   case IRPosition::IRP_INVALID:
4459   case IRPosition::IRP_FLOAT:
4460   case IRPosition::IRP_ARGUMENT:
4461   case IRPosition::IRP_RETURNED:
4462   case IRPosition::IRP_CALL_SITE_RETURNED:
4463   case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable(
        "KernelInfo can only be created for function or call site position!");
4465   case IRPosition::IRP_CALL_SITE:
4466     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4467     break;
4468   case IRPosition::IRP_FUNCTION:
4469     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4470     break;
4471   }
4472 
4473   return *AA;
4474 }
4475 
4476 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4477                                                         Attributor &A) {
4478   AAFoldRuntimeCall *AA = nullptr;
4479   switch (IRP.getPositionKind()) {
4480   case IRPosition::IRP_INVALID:
4481   case IRPosition::IRP_FLOAT:
4482   case IRPosition::IRP_ARGUMENT:
4483   case IRPosition::IRP_RETURNED:
4484   case IRPosition::IRP_FUNCTION:
4485   case IRPosition::IRP_CALL_SITE:
4486   case IRPosition::IRP_CALL_SITE_ARGUMENT:
    llvm_unreachable(
        "AAFoldRuntimeCall can only be created for call site returned "
        "position!");
4488   case IRPosition::IRP_CALL_SITE_RETURNED:
4489     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4490     break;
4491   }
4492 
4493   return *AA;
4494 }
4495 
4496 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4497   if (!containsOpenMP(M))
4498     return PreservedAnalyses::all();
4499   if (DisableOpenMPOptimizations)
4500     return PreservedAnalyses::all();
4501 
4502   FunctionAnalysisManager &FAM =
4503       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4504   KernelSet Kernels = getDeviceKernels(M);
4505 
4506   auto IsCalled = [&](Function &F) {
4507     if (Kernels.contains(&F))
4508       return true;
4509     for (const User *U : F.users())
4510       if (!isa<BlockAddress>(U))
4511         return true;
4512     return false;
4513   };
4514 
4515   auto EmitRemark = [&](Function &F) {
4516     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4517     ORE.emit([&]() {
4518       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4519       return ORA << "Could not internalize function. "
4520                  << "Some optimizations may not be possible. [OMP140]";
4521     });
4522   };
4523 
  // Create internal copies of each function if this is a kernel module. This
  // allows interprocedural passes to see every call edge.
4526   DenseMap<Function *, Function *> InternalizedMap;
4527   if (isOpenMPDevice(M)) {
4528     SmallPtrSet<Function *, 16> InternalizeFns;
4529     for (Function &F : M)
4530       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4531           !DisableInternalization) {
4532         if (Attributor::isInternalizable(F)) {
4533           InternalizeFns.insert(&F);
4534         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4535           EmitRemark(F);
4536         }
4537       }
4538 
4539     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4540   }
4541 
4542   // Look at every function in the Module unless it was internalized.
4543   SmallVector<Function *, 16> SCC;
4544   for (Function &F : M)
4545     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4546       SCC.push_back(&F);
4547 
4548   if (SCC.empty())
4549     return PreservedAnalyses::all();
4550 
4551   AnalysisGetter AG(FAM);
4552 
4553   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4554     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4555   };
4556 
4557   BumpPtrAllocator Allocator;
4558   CallGraphUpdater CGUpdater;
4559 
4560   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4561   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4562 
4563   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4564   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4565                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4566 
4567   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4568   bool Changed = OMPOpt.run(true);
4569 
4570   // Optionally inline device functions for potentially better performance.
4571   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4572     for (Function &F : M)
4573       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4574           !F.hasFnAttribute(Attribute::NoInline))
4575         F.addFnAttr(Attribute::AlwaysInline);
4576 
4577   if (PrintModuleAfterOptimizations)
4578     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4579 
4580   if (Changed)
4581     return PreservedAnalyses::none();
4582 
4583   return PreservedAnalyses::all();
4584 }
4585 
4586 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4587                                           CGSCCAnalysisManager &AM,
4588                                           LazyCallGraph &CG,
4589                                           CGSCCUpdateResult &UR) {
4590   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4591     return PreservedAnalyses::all();
4592   if (DisableOpenMPOptimizations)
4593     return PreservedAnalyses::all();
4594 
4595   SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCCs.
4597   for (LazyCallGraph::Node &N : C) {
4598     Function *Fn = &N.getFunction();
4599     SCC.push_back(Fn);
4600   }
4601 
4602   if (SCC.empty())
4603     return PreservedAnalyses::all();
4604 
4605   Module &M = *C.begin()->getFunction().getParent();
4606 
4607   KernelSet Kernels = getDeviceKernels(M);
4608 
4609   FunctionAnalysisManager &FAM =
4610       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4611 
4612   AnalysisGetter AG(FAM);
4613 
4614   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4615     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4616   };
4617 
4618   BumpPtrAllocator Allocator;
4619   CallGraphUpdater CGUpdater;
4620   CGUpdater.initialize(CG, C, AM, UR);
4621 
4622   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4623   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4624                                 /*CGSCC*/ Functions, Kernels);
4625 
4626   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4627   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4628                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4629 
4630   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4631   bool Changed = OMPOpt.run(false);
4632 
4633   if (PrintModuleAfterOptimizations)
4634     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4635 
4636   if (Changed)
4637     return PreservedAnalyses::none();
4638 
4639   return PreservedAnalyses::all();
4640 }
4641 
4642 namespace {
4643 
4644 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4645   CallGraphUpdater CGUpdater;
4646   static char ID;
4647 
4648   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4649     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4650   }
4651 
4652   void getAnalysisUsage(AnalysisUsage &AU) const override {
4653     CallGraphSCCPass::getAnalysisUsage(AU);
4654   }
4655 
4656   bool runOnSCC(CallGraphSCC &CGSCC) override {
4657     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4658       return false;
4659     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4660       return false;
4661 
4662     SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCCs.
4664     for (CallGraphNode *CGN : CGSCC) {
4665       Function *Fn = CGN->getFunction();
4666       if (!Fn || Fn->isDeclaration())
4667         continue;
4668       SCC.push_back(Fn);
4669     }
4670 
4671     if (SCC.empty())
4672       return false;
4673 
4674     Module &M = CGSCC.getCallGraph().getModule();
4675     KernelSet Kernels = getDeviceKernels(M);
4676 
4677     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4678     CGUpdater.initialize(CG, CGSCC);
4679 
    // Cache remark emitters per function to avoid rebuilding them.
4681     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4682     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4683       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4684       if (!ORE)
4685         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4686       return *ORE;
4687     };
4688 
4689     AnalysisGetter AG;
4690     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4691     BumpPtrAllocator Allocator;
4692     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4693                                   Allocator,
4694                                   /*CGSCC*/ Functions, Kernels);
4695 
4696     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4697     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4698                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4699 
4700     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4701     bool Result = OMPOpt.run(false);
4702 
4703     if (PrintModuleAfterOptimizations)
4704       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4705 
4706     return Result;
4707   }
4708 
4709   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4710 };
4711 
4712 } // end anonymous namespace
4713 
4714 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4715   // TODO: Create a more cross-platform way of determining device kernels.
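  // Kernels are currently identified via NVPTX-style "nvvm.annotations"
  // metadata; an entry looks roughly like (illustrative):
  //   !{void ()* @kernel_fn, !"kernel", i32 1}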
4716   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4717   KernelSet Kernels;
4718 
4719   if (!MD)
4720     return Kernels;
4721 
4722   for (auto *Op : MD->operands()) {
4723     if (Op->getNumOperands() < 2)
4724       continue;
4725     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4726     if (!KindID || KindID->getString() != "kernel")
4727       continue;
4728 
4729     Function *KernelFn =
4730         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4731     if (!KernelFn)
4732       continue;
4733 
4734     ++NumOpenMPTargetRegionKernels;
4735 
4736     Kernels.insert(KernelFn);
4737   }
4738 
4739   return Kernels;
4740 }
4741 
4742 bool llvm::omp::containsOpenMP(Module &M) {
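  // The frontend marks OpenMP modules with the "openmp" module flag, carrying
  // the OpenMP version, e.g. (illustrative): !{i32 7, !"openmp", i32 50}.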
4743   Metadata *MD = M.getModuleFlag("openmp");
4744   if (!MD)
4745     return false;
4746 
4747   return true;
4748 }
4749 
4750 bool llvm::omp::isOpenMPDevice(Module &M) {
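  // Device-side (target) compilations are additionally marked with the
  // "openmp-device" module flag by the frontend.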
4751   Metadata *MD = M.getModuleFlag("openmp-device");
4752   if (!MD)
4753     return false;
4754 
4755   return true;
4756 }
4757 
4758 char OpenMPOptCGSCCLegacyPass::ID = 0;
4759 
4760 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4761                       "OpenMP specific optimizations", false, false)
4762 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4763 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4764                     "OpenMP specific optimizations", false, false)
4765 
4766 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4767   return new OpenMPOptCGSCCLegacyPass();
4768 }
4769