1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
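// As an illustrative example (simplified IR), deduplication of runtime calls
// replaces repeated, side-effect-free queries such as
//
//   %tid0 = call i32 @omp_get_thread_num()
//   ...
//   %tid1 = call i32 @omp_get_thread_num()
//
// with a single call whose result is reused, assuming the queried state does
// not change in between.
//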
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/IR/IntrinsicsAMDGPU.h"
37 #include "llvm/IR/IntrinsicsNVPTX.h"
38 #include "llvm/InitializePasses.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Transforms/IPO.h"
41 #include "llvm/Transforms/IPO/Attributor.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
44 #include "llvm/Transforms/Utils/CodeExtractor.h"
45 
46 using namespace llvm;
47 using namespace omp;
48 
49 #define DEBUG_TYPE "openmp-opt"
50 
51 static cl::opt<bool> DisableOpenMPOptimizations(
52     "openmp-opt-disable", cl::ZeroOrMore,
53     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
54     cl::init(false));
55 
56 static cl::opt<bool> EnableParallelRegionMerging(
57     "openmp-opt-enable-merging", cl::ZeroOrMore,
58     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
59     cl::init(false));
60 
61 static cl::opt<bool>
62     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
63                            cl::desc("Disable function internalization."),
64                            cl::Hidden, cl::init(false));
65 
66 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
67                                     cl::Hidden);
68 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
69                                         cl::init(false), cl::Hidden);
70 
71 static cl::opt<bool> HideMemoryTransferLatency(
72     "openmp-hide-memory-transfer-latency",
73     cl::desc("[WIP] Tries to hide the latency of host to device memory"
74              " transfers"),
75     cl::Hidden, cl::init(false));
76 
77 static cl::opt<bool> DisableOpenMPOptDeglobalization(
78     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
79     cl::desc("Disable OpenMP optimizations involving deglobalization."),
80     cl::Hidden, cl::init(false));
81 
82 static cl::opt<bool> DisableOpenMPOptSPMDization(
83     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
84     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
85     cl::Hidden, cl::init(false));
86 
87 static cl::opt<bool> DisableOpenMPOptFolding(
88     "openmp-opt-disable-folding", cl::ZeroOrMore,
89     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
90     cl::init(false));
91 
92 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
93     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
94     cl::desc("Disable OpenMP optimizations that replace the state machine."),
95     cl::Hidden, cl::init(false));
96 
97 static cl::opt<bool> PrintModuleAfterOptimizations(
98     "openmp-opt-print-module", cl::ZeroOrMore,
99     cl::desc("Print the current module after OpenMP optimizations."),
100     cl::Hidden, cl::init(false));
101 
102 static cl::opt<bool> AlwaysInlineDeviceFunctions(
103     "openmp-opt-inline-device", cl::ZeroOrMore,
    cl::desc("Inline all applicable functions on the device."), cl::Hidden,
105     cl::init(false));
106 
107 static cl::opt<bool>
108     EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
109                          cl::desc("Enables more verbose remarks."), cl::Hidden,
110                          cl::init(false));
111 
112 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
113           "Number of OpenMP runtime calls deduplicated");
114 STATISTIC(NumOpenMPParallelRegionsDeleted,
115           "Number of OpenMP parallel regions deleted");
116 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
117           "Number of OpenMP runtime functions identified");
118 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
119           "Number of OpenMP runtime function uses identified");
120 STATISTIC(NumOpenMPTargetRegionKernels,
121           "Number of OpenMP target region entry points (=kernels) identified");
122 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
123           "Number of OpenMP target region entry points (=kernels) executed in "
124           "SPMD-mode instead of generic-mode");
125 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
126           "Number of OpenMP target region entry points (=kernels) executed in "
          "generic-mode without a state machine");
128 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
129           "Number of OpenMP target region entry points (=kernels) executed in "
130           "generic-mode with customized state machines with fallback");
131 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
132           "Number of OpenMP target region entry points (=kernels) executed in "
133           "generic-mode with customized state machines without fallback");
134 STATISTIC(
135     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
136     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
137 STATISTIC(NumOpenMPParallelRegionsMerged,
138           "Number of OpenMP parallel regions merged");
139 STATISTIC(NumBytesMovedToSharedMemory,
140           "Amount of memory pushed to shared memory");
141 
142 #if !defined(NDEBUG)
143 static constexpr auto TAG = "[" DEBUG_TYPE "]";
144 #endif
145 
146 namespace {
147 
148 enum class AddressSpace : unsigned {
149   Generic = 0,
150   Global = 1,
151   Shared = 3,
152   Constant = 4,
153   Local = 5,
154 };
155 
156 struct AAHeapToShared;
157 
158 struct AAICVTracker;
159 
160 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
161 /// Attributor runs.
162 struct OMPInformationCache : public InformationCache {
163   OMPInformationCache(Module &M, AnalysisGetter &AG,
164                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
165                       SmallPtrSetImpl<Kernel> &Kernels)
166       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
167         Kernels(Kernels) {
168 
169     OMPBuilder.initialize();
170     initializeRuntimeFunctions();
171     initializeInternalControlVars();
172   }
173 
174   /// Generic information that describes an internal control variable.
175   struct InternalControlVarInfo {
    /// The kind, as described by the InternalControlVar enum.
177     InternalControlVar Kind;
178 
179     /// The name of the ICV.
180     StringRef Name;
181 
182     /// Environment variable associated with this ICV.
183     StringRef EnvVarName;
184 
185     /// Initial value kind.
186     ICVInitValue InitKind;
187 
188     /// Initial value.
189     ConstantInt *InitValue;
190 
191     /// Setter RTL function associated with this ICV.
192     RuntimeFunction Setter;
193 
194     /// Getter RTL function associated with this ICV.
195     RuntimeFunction Getter;
196 
    /// RTL function corresponding to the override clause of this ICV.
198     RuntimeFunction Clause;
199   };
200 
  /// Generic information that describes a runtime function.
202   struct RuntimeFunctionInfo {
203 
204     /// The kind, as described by the RuntimeFunction enum.
205     RuntimeFunction Kind;
206 
207     /// The name of the function.
208     StringRef Name;
209 
210     /// Flag to indicate a variadic function.
211     bool IsVarArg;
212 
213     /// The return type of the function.
214     Type *ReturnType;
215 
216     /// The argument types of the function.
217     SmallVector<Type *, 8> ArgumentTypes;
218 
219     /// The declaration if available.
220     Function *Declaration = nullptr;
221 
222     /// Uses of this runtime function per function containing the use.
223     using UseVector = SmallVector<Use *, 16>;
224 
225     /// Clear UsesMap for runtime function.
226     void clearUsesMap() { UsesMap.clear(); }
227 
228     /// Boolean conversion that is true if the runtime function was found.
229     operator bool() const { return Declaration; }
230 
231     /// Return the vector of uses in function \p F.
232     UseVector &getOrCreateUseVector(Function *F) {
233       std::shared_ptr<UseVector> &UV = UsesMap[F];
234       if (!UV)
235         UV = std::make_shared<UseVector>();
236       return *UV;
237     }
238 
239     /// Return the vector of uses in function \p F or `nullptr` if there are
240     /// none.
241     const UseVector *getUseVector(Function &F) const {
242       auto I = UsesMap.find(&F);
243       if (I != UsesMap.end())
244         return I->second.get();
245       return nullptr;
246     }
247 
248     /// Return how many functions contain uses of this runtime function.
249     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
250 
251     /// Return the number of arguments (or the minimal number for variadic
252     /// functions).
253     size_t getNumArgs() const { return ArgumentTypes.size(); }
254 
    /// Run the callback \p CB on each use and forget the use if the result is
    /// true. The callback will be fed the function in which the use was
    /// encountered as the second argument.
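    ///
    /// A minimal usage sketch (hypothetical callback); return true to forget
    /// uses that were erased so they are not visited again:
    ///
    ///   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
    ///     if (auto *CI = dyn_cast<CallInst>(U.getUser())) {
    ///       CI->eraseFromParent();
    ///       return true; // Forget this use.
    ///     }
    ///     return false;
    ///   });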
258     void foreachUse(SmallVectorImpl<Function *> &SCC,
259                     function_ref<bool(Use &, Function &)> CB) {
260       for (Function *F : SCC)
261         foreachUse(CB, F);
262     }
263 
264     /// Run the callback \p CB on each use within the function \p F and forget
265     /// the use if the result is true.
266     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
267       SmallVector<unsigned, 8> ToBeDeleted;
268       ToBeDeleted.clear();
269 
270       unsigned Idx = 0;
271       UseVector &UV = getOrCreateUseVector(F);
272 
273       for (Use *U : UV) {
274         if (CB(*U, *F))
275           ToBeDeleted.push_back(Idx);
276         ++Idx;
277       }
278 
279       // Remove the to-be-deleted indices in reverse order as prior
280       // modifications will not modify the smaller indices.
281       while (!ToBeDeleted.empty()) {
282         unsigned Idx = ToBeDeleted.pop_back_val();
283         UV[Idx] = UV.back();
284         UV.pop_back();
285       }
286     }
287 
288   private:
289     /// Map from functions to all uses of this runtime function contained in
290     /// them.
291     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
292 
293   public:
294     /// Iterators for the uses of this runtime function.
295     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
296     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
297   };
298 
299   /// An OpenMP-IR-Builder instance
300   OpenMPIRBuilder OMPBuilder;
301 
302   /// Map from runtime function kind to the runtime function description.
303   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
304                   RuntimeFunction::OMPRTL___last>
305       RFIs;
306 
307   /// Map from function declarations/definitions to their runtime enum type.
308   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
309 
310   /// Map from ICV kind to the ICV description.
311   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
312                   InternalControlVar::ICV___last>
313       ICVs;
314 
315   /// Helper to initialize all internal control variable information for those
316   /// defined in OMPKinds.def.
317   void initializeInternalControlVars() {
318 #define ICV_RT_SET(_Name, RTL)                                                 \
319   {                                                                            \
320     auto &ICV = ICVs[_Name];                                                   \
321     ICV.Setter = RTL;                                                          \
322   }
323 #define ICV_RT_GET(Name, RTL)                                                  \
324   {                                                                            \
325     auto &ICV = ICVs[Name];                                                    \
326     ICV.Getter = RTL;                                                          \
327   }
328 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
329   {                                                                            \
330     auto &ICV = ICVs[Enum];                                                    \
331     ICV.Name = _Name;                                                          \
332     ICV.Kind = Enum;                                                           \
333     ICV.InitKind = Init;                                                       \
334     ICV.EnvVarName = _EnvVarName;                                              \
335     switch (ICV.InitKind) {                                                    \
336     case ICV_IMPLEMENTATION_DEFINED:                                           \
337       ICV.InitValue = nullptr;                                                 \
338       break;                                                                   \
339     case ICV_ZERO:                                                             \
340       ICV.InitValue = ConstantInt::get(                                        \
341           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
342       break;                                                                   \
343     case ICV_FALSE:                                                            \
344       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
345       break;                                                                   \
346     case ICV_LAST:                                                             \
347       break;                                                                   \
348     }                                                                          \
349   }
350 #include "llvm/Frontend/OpenMP/OMPKinds.def"
351   }
352 
353   /// Returns true if the function declaration \p F matches the runtime
354   /// function types, that is, return type \p RTFRetType, and argument types
355   /// \p RTFArgTypes.
356   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
357                                   SmallVector<Type *, 8> &RTFArgTypes) {
358     // TODO: We should output information to the user (under debug output
359     //       and via remarks).
360 
361     if (!F)
362       return false;
363     if (F->getReturnType() != RTFRetType)
364       return false;
365     if (F->arg_size() != RTFArgTypes.size())
366       return false;
367 
368     auto RTFTyIt = RTFArgTypes.begin();
369     for (Argument &Arg : F->args()) {
370       if (Arg.getType() != *RTFTyIt)
371         return false;
372 
373       ++RTFTyIt;
374     }
375 
376     return true;
377   }
378 
379   // Helper to collect all uses of the declaration in the UsesMap.
380   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
381     unsigned NumUses = 0;
382     if (!RFI.Declaration)
383       return NumUses;
384     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
385 
386     if (CollectStats) {
387       NumOpenMPRuntimeFunctionsIdentified += 1;
388       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
389     }
390 
391     // TODO: We directly convert uses into proper calls and unknown uses.
392     for (Use &U : RFI.Declaration->uses()) {
393       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
394         if (ModuleSlice.count(UserI->getFunction())) {
395           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
396           ++NumUses;
397         }
398       } else {
399         RFI.getOrCreateUseVector(nullptr).push_back(&U);
400         ++NumUses;
401       }
402     }
403     return NumUses;
404   }
405 
406   // Helper function to recollect uses of a runtime function.
407   void recollectUsesForFunction(RuntimeFunction RTF) {
408     auto &RFI = RFIs[RTF];
409     RFI.clearUsesMap();
410     collectUses(RFI, /*CollectStats*/ false);
411   }
412 
413   // Helper function to recollect uses of all runtime functions.
414   void recollectUses() {
415     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
416       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
417   }
418 
  /// Helper to initialize all runtime function information for those defined
  /// in OMPKinds.def.
421   void initializeRuntimeFunctions() {
422     Module &M = *((*ModuleSlice.begin())->getParent());
423 
424     // Helper macros for handling __VA_ARGS__ in OMP_RTL
425 #define OMP_TYPE(VarName, ...)                                                 \
426   Type *VarName = OMPBuilder.VarName;                                          \
427   (void)VarName;
428 
429 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
430   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
431   (void)VarName##Ty;                                                           \
432   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
433   (void)VarName##PtrTy;
434 
435 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
436   FunctionType *VarName = OMPBuilder.VarName;                                  \
437   (void)VarName;                                                               \
438   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
439   (void)VarName##Ptr;
440 
441 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
442   StructType *VarName = OMPBuilder.VarName;                                    \
443   (void)VarName;                                                               \
444   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
445   (void)VarName##Ptr;
446 
447 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
448   {                                                                            \
449     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
450     Function *F = M.getFunction(_Name);                                        \
451     RTLFunctions.insert(F);                                                    \
452     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
453       RuntimeFunctionIDMap[F] = _Enum;                                         \
454       F->removeFnAttr(Attribute::NoInline);                                    \
455       auto &RFI = RFIs[_Enum];                                                 \
456       RFI.Kind = _Enum;                                                        \
457       RFI.Name = _Name;                                                        \
458       RFI.IsVarArg = _IsVarArg;                                                \
459       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
460       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
461       RFI.Declaration = F;                                                     \
462       unsigned NumUses = collectUses(RFI);                                     \
463       (void)NumUses;                                                           \
464       LLVM_DEBUG({                                                             \
465         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
466                << " found\n";                                                  \
467         if (RFI.Declaration)                                                   \
468           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
469                  << RFI.getNumFunctionsWithUses()                              \
470                  << " different functions.\n";                                 \
471       });                                                                      \
472     }                                                                          \
473   }
474 #include "llvm/Frontend/OpenMP/OMPKinds.def"
475 
476     // TODO: We should attach the attributes defined in OMPKinds.def.
477   }
478 
479   /// Collection of known kernels (\see Kernel) in the module.
480   SmallPtrSetImpl<Kernel> &Kernels;
481 
  /// Collection of known OpenMP runtime functions.
483   DenseSet<const Function *> RTLFunctions;
484 };
485 
486 template <typename Ty, bool InsertInvalidates = true>
487 struct BooleanStateWithSetVector : public BooleanState {
488   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
489   bool insert(const Ty &Elem) {
490     if (InsertInvalidates)
491       BooleanState::indicatePessimisticFixpoint();
492     return Set.insert(Elem);
493   }
494 
495   const Ty &operator[](int Idx) const { return Set[Idx]; }
496   bool operator==(const BooleanStateWithSetVector &RHS) const {
497     return BooleanState::operator==(RHS) && Set == RHS.Set;
498   }
499   bool operator!=(const BooleanStateWithSetVector &RHS) const {
500     return !(*this == RHS);
501   }
502 
503   bool empty() const { return Set.empty(); }
504   size_t size() const { return Set.size(); }
505 
506   /// "Clamp" this state with \p RHS.
507   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
508     BooleanState::operator^=(RHS);
509     Set.insert(RHS.Set.begin(), RHS.Set.end());
510     return *this;
511   }
512 
513 private:
514   /// A set to keep track of elements.
515   SetVector<Ty> Set;
516 
517 public:
518   typename decltype(Set)::iterator begin() { return Set.begin(); }
519   typename decltype(Set)::iterator end() { return Set.end(); }
520   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
521   typename decltype(Set)::const_iterator end() const { return Set.end(); }
522 };
523 
524 template <typename Ty, bool InsertInvalidates = true>
525 using BooleanStateWithPtrSetVector =
526     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
527 
528 struct KernelInfoState : AbstractState {
529   /// Flag to track if we reached a fixpoint.
530   bool IsAtFixpoint = false;
531 
532   /// The parallel regions (identified by the outlined parallel functions) that
533   /// can be reached from the associated function.
534   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
535       ReachedKnownParallelRegions;
536 
537   /// State to track what parallel region we might reach.
538   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
539 
  /// State to track if we are in SPMD-mode, assumed or known, and why we
  /// decided we cannot be. If it is assumed, then RequiresFullRuntime should
  /// also be false.
543   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
544 
545   /// The __kmpc_target_init call in this kernel, if any. If we find more than
546   /// one we abort as the kernel is malformed.
547   CallBase *KernelInitCB = nullptr;
548 
549   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
550   /// one we abort as the kernel is malformed.
551   CallBase *KernelDeinitCB = nullptr;
552 
553   /// Flag to indicate if the associated function is a kernel entry.
554   bool IsKernelEntry = false;
555 
556   /// State to track what kernel entries can reach the associated function.
557   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
558 
  /// State to indicate if we can track parallel level of the associated
  /// function. We will give up tracking if we encounter an unknown caller or
  /// the caller is __kmpc_parallel_51.
562   BooleanStateWithSetVector<uint8_t> ParallelLevels;
563 
564   /// Abstract State interface
565   ///{
566 
567   KernelInfoState() {}
568   KernelInfoState(bool BestState) {
569     if (!BestState)
570       indicatePessimisticFixpoint();
571   }
572 
573   /// See AbstractState::isValidState(...)
574   bool isValidState() const override { return true; }
575 
576   /// See AbstractState::isAtFixpoint(...)
577   bool isAtFixpoint() const override { return IsAtFixpoint; }
578 
579   /// See AbstractState::indicatePessimisticFixpoint(...)
580   ChangeStatus indicatePessimisticFixpoint() override {
581     IsAtFixpoint = true;
582     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
583     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
584     return ChangeStatus::CHANGED;
585   }
586 
587   /// See AbstractState::indicateOptimisticFixpoint(...)
588   ChangeStatus indicateOptimisticFixpoint() override {
589     IsAtFixpoint = true;
590     return ChangeStatus::UNCHANGED;
591   }
592 
593   /// Return the assumed state
594   KernelInfoState &getAssumed() { return *this; }
595   const KernelInfoState &getAssumed() const { return *this; }
596 
597   bool operator==(const KernelInfoState &RHS) const {
598     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
599       return false;
600     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
601       return false;
602     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
603       return false;
604     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
605       return false;
606     return true;
607   }
608 
  /// Returns true if this kernel might contain an OpenMP parallel region.
610   bool mayContainParallelRegion() {
611     return !ReachedKnownParallelRegions.empty() ||
612            !ReachedUnknownParallelRegions.empty();
613   }
614 
615   /// Return empty set as the best state of potential values.
616   static KernelInfoState getBestState() { return KernelInfoState(true); }
617 
618   static KernelInfoState getBestState(KernelInfoState &KIS) {
619     return getBestState();
620   }
621 
622   /// Return full set as the worst state of potential values.
623   static KernelInfoState getWorstState() { return KernelInfoState(false); }
624 
625   /// "Clamp" this state with \p KIS.
626   KernelInfoState operator^=(const KernelInfoState &KIS) {
627     // Do not merge two different _init and _deinit call sites.
628     if (KIS.KernelInitCB) {
629       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
630         indicatePessimisticFixpoint();
631       KernelInitCB = KIS.KernelInitCB;
632     }
633     if (KIS.KernelDeinitCB) {
634       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
635         indicatePessimisticFixpoint();
636       KernelDeinitCB = KIS.KernelDeinitCB;
637     }
638     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
639     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
640     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
641     return *this;
642   }
643 
644   KernelInfoState operator&=(const KernelInfoState &KIS) {
645     return (*this ^= KIS);
646   }
647 
648   ///}
649 };
650 
651 /// Used to map the values physically (in the IR) stored in an offload
652 /// array, to a vector in memory.
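///
/// As an illustrative example (typed-pointer IR), for a two-element array
///
///   %array = alloca [2 x i8*]
///   %gep0 = getelementptr [2 x i8*], [2 x i8*]* %array, i64 0, i64 0
///   store i8* %a, i8** %gep0
///   %gep1 = getelementptr [2 x i8*], [2 x i8*]* %array, i64 0, i64 1
///   store i8* %b, i8** %gep1
///
/// StoredValues ends up holding {%a, %b} and LastAccesses the two stores.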
653 struct OffloadArray {
654   /// Physical array (in the IR).
655   AllocaInst *Array = nullptr;
656   /// Mapped values.
657   SmallVector<Value *, 8> StoredValues;
658   /// Last stores made in the offload array.
659   SmallVector<StoreInst *, 8> LastAccesses;
660 
661   OffloadArray() = default;
662 
663   /// Initializes the OffloadArray with the values stored in \p Array before
664   /// instruction \p Before is reached. Returns false if the initialization
665   /// fails.
666   /// This MUST be used immediately after the construction of the object.
667   bool initialize(AllocaInst &Array, Instruction &Before) {
668     if (!Array.getAllocatedType()->isArrayTy())
669       return false;
670 
671     if (!getValues(Array, Before))
672       return false;
673 
674     this->Array = &Array;
675     return true;
676   }
677 
678   static const unsigned DeviceIDArgNum = 1;
679   static const unsigned BasePtrsArgNum = 3;
680   static const unsigned PtrsArgNum = 4;
681   static const unsigned SizesArgNum = 5;
682 
683 private:
684   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
685   /// \p Array, leaving StoredValues with the values stored before the
686   /// instruction \p Before is reached.
687   bool getValues(AllocaInst &Array, Instruction &Before) {
688     // Initialize container.
689     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
690     StoredValues.assign(NumValues, nullptr);
691     LastAccesses.assign(NumValues, nullptr);
692 
693     // TODO: This assumes the instruction \p Before is in the same
694     //  BasicBlock as Array. Make it general, for any control flow graph.
695     BasicBlock *BB = Array.getParent();
696     if (BB != Before.getParent())
697       return false;
698 
699     const DataLayout &DL = Array.getModule()->getDataLayout();
700     const unsigned int PointerSize = DL.getPointerSize();
701 
702     for (Instruction &I : *BB) {
703       if (&I == &Before)
704         break;
705 
706       if (!isa<StoreInst>(&I))
707         continue;
708 
709       auto *S = cast<StoreInst>(&I);
710       int64_t Offset = -1;
711       auto *Dst =
712           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
713       if (Dst == &Array) {
714         int64_t Idx = Offset / PointerSize;
715         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
716         LastAccesses[Idx] = S;
717       }
718     }
719 
720     return isFilled();
721   }
722 
723   /// Returns true if all values in StoredValues and
724   /// LastAccesses are not nullptrs.
725   bool isFilled() {
726     const unsigned NumValues = StoredValues.size();
727     for (unsigned I = 0; I < NumValues; ++I) {
728       if (!StoredValues[I] || !LastAccesses[I])
729         return false;
730     }
731 
732     return true;
733   }
734 };
735 
736 struct OpenMPOpt {
737 
738   using OptimizationRemarkGetter =
739       function_ref<OptimizationRemarkEmitter &(Function *)>;
740 
741   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
742             OptimizationRemarkGetter OREGetter,
743             OMPInformationCache &OMPInfoCache, Attributor &A)
744       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
745         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
746 
747   /// Check if any remarks are enabled for openmp-opt
748   bool remarksEnabled() {
749     auto &Ctx = M.getContext();
750     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
751   }
752 
753   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
754   bool run(bool IsModulePass) {
755     if (SCC.empty())
756       return false;
757 
758     bool Changed = false;
759 
760     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
761                       << " functions in a slice with "
762                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
763 
764     if (IsModulePass) {
765       Changed |= runAttributor(IsModulePass);
766 
767       // Recollect uses, in case Attributor deleted any.
768       OMPInfoCache.recollectUses();
769 
770       // TODO: This should be folded into buildCustomStateMachine.
771       Changed |= rewriteDeviceCodeStateMachine();
772 
773       if (remarksEnabled())
774         analysisGlobalization();
775     } else {
776       if (PrintICVValues)
777         printICVs();
778       if (PrintOpenMPKernels)
779         printKernels();
780 
781       Changed |= runAttributor(IsModulePass);
782 
783       // Recollect uses, in case Attributor deleted any.
784       OMPInfoCache.recollectUses();
785 
786       Changed |= deleteParallelRegions();
787 
788       if (HideMemoryTransferLatency)
789         Changed |= hideMemTransfersLatency();
790       Changed |= deduplicateRuntimeCalls();
791       if (EnableParallelRegionMerging) {
792         if (mergeParallelRegions()) {
793           deduplicateRuntimeCalls();
794           Changed = true;
795         }
796       }
797     }
798 
799     return Changed;
800   }
801 
802   /// Print initial ICV values for testing.
803   /// FIXME: This should be done from the Attributor once it is added.
804   void printICVs() const {
805     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
806                                  ICV_proc_bind};
807 
808     for (Function *F : OMPInfoCache.ModuleSlice) {
809       for (auto ICV : ICVs) {
810         auto ICVInfo = OMPInfoCache.ICVs[ICV];
811         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
812           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
813                      << " Value: "
814                      << (ICVInfo.InitValue
815                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
816                              : "IMPLEMENTATION_DEFINED");
817         };
818 
819         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
820       }
821     }
822   }
823 
824   /// Print OpenMP GPU kernels for testing.
825   void printKernels() const {
826     for (Function *F : SCC) {
827       if (!OMPInfoCache.Kernels.count(F))
828         continue;
829 
830       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
831         return ORA << "OpenMP GPU kernel "
832                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
833       };
834 
835       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
836     }
837   }
838 
  /// Return the call if \p U is a callee use in a regular call. If \p RFI is
  /// given, it has to be the callee or nullptr is returned.
841   static CallInst *getCallIfRegularCall(
842       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
843     CallInst *CI = dyn_cast<CallInst>(U.getUser());
844     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
845         (!RFI ||
846          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
847       return CI;
848     return nullptr;
849   }
850 
  /// Return the call if \p V is a regular call. If \p RFI is given, it has to
  /// be the callee or nullptr is returned.
853   static CallInst *getCallIfRegularCall(
854       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
855     CallInst *CI = dyn_cast<CallInst>(&V);
856     if (CI && !CI->hasOperandBundles() &&
857         (!RFI ||
858          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
859       return CI;
860     return nullptr;
861   }
862 
863 private:
864   /// Merge parallel regions when it is safe.
865   bool mergeParallelRegions() {
866     const unsigned CallbackCalleeOperand = 2;
867     const unsigned CallbackFirstArgOperand = 3;
868     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
869 
870     // Check if there are any __kmpc_fork_call calls to merge.
871     OMPInformationCache::RuntimeFunctionInfo &RFI =
872         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
873 
874     if (!RFI.Declaration)
875       return false;
876 
877     // Unmergable calls that prevent merging a parallel region.
878     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
879         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
880         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
881     };
882 
883     bool Changed = false;
884     LoopInfo *LI = nullptr;
885     DominatorTree *DT = nullptr;
886 
887     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
888 
889     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
890     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
891                          BasicBlock &ContinuationIP) {
892       BasicBlock *CGStartBB = CodeGenIP.getBlock();
893       BasicBlock *CGEndBB =
894           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
895       assert(StartBB != nullptr && "StartBB should not be null");
896       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
897       assert(EndBB != nullptr && "EndBB should not be null");
898       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
899     };
900 
901     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
902                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
903       ReplacementValue = &Inner;
904       return CodeGenIP;
905     };
906 
907     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
908 
909     /// Create a sequential execution region within a merged parallel region,
910     /// encapsulated in a master construct with a barrier for synchronization.
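    ///
    /// As an illustrative sketch, sequential code found between two mergable
    /// regions ends up inside the merged parallel region roughly as
    ///
    ///   #pragma omp master
    ///   { /* sequential code */ }
    ///   #pragma omp barrier
    ///
    /// so a single thread executes it and all threads synchronize afterwards.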
911     auto CreateSequentialRegion = [&](Function *OuterFn,
912                                       BasicBlock *OuterPredBB,
913                                       Instruction *SeqStartI,
914                                       Instruction *SeqEndI) {
915       // Isolate the instructions of the sequential region to a separate
916       // block.
917       BasicBlock *ParentBB = SeqStartI->getParent();
918       BasicBlock *SeqEndBB =
919           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
920       BasicBlock *SeqAfterBB =
921           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
922       BasicBlock *SeqStartBB =
923           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
924 
925       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
926              "Expected a different CFG");
927       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
928       ParentBB->getTerminator()->eraseFromParent();
929 
930       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
931                            BasicBlock &ContinuationIP) {
932         BasicBlock *CGStartBB = CodeGenIP.getBlock();
933         BasicBlock *CGEndBB =
934             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
935         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
936         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
937         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
938         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
939       };
940       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
941 
942       // Find outputs from the sequential region to outside users and
943       // broadcast their values to them.
944       for (Instruction &I : *SeqStartBB) {
945         SmallPtrSet<Instruction *, 4> OutsideUsers;
946         for (User *Usr : I.users()) {
947           Instruction &UsrI = *cast<Instruction>(Usr);
          // Ignore outputs to lifetime intrinsics; code extraction for the
          // merged parallel region will fix them.
950           if (UsrI.isLifetimeStartOrEnd())
951             continue;
952 
953           if (UsrI.getParent() != SeqStartBB)
954             OutsideUsers.insert(&UsrI);
955         }
956 
957         if (OutsideUsers.empty())
958           continue;
959 
960         // Emit an alloca in the outer region to store the broadcasted
961         // value.
962         const DataLayout &DL = M.getDataLayout();
963         AllocaInst *AllocaI = new AllocaInst(
964             I.getType(), DL.getAllocaAddrSpace(), nullptr,
965             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
966 
967         // Emit a store instruction in the sequential BB to update the
968         // value.
969         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
970 
971         // Emit a load instruction and replace the use of the output value
972         // with it.
973         for (Instruction *UsrI : OutsideUsers) {
974           LoadInst *LoadI = new LoadInst(
975               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
976           UsrI->replaceUsesOfWith(&I, LoadI);
977         }
978       }
979 
980       OpenMPIRBuilder::LocationDescription Loc(
981           InsertPointTy(ParentBB, ParentBB->end()), DL);
982       InsertPointTy SeqAfterIP =
983           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
984 
985       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
986 
987       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
988 
989       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
990                         << "\n");
991     };
992 
993     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
994     // contained in BB and only separated by instructions that can be
995     // redundantly executed in parallel. The block BB is split before the first
996     // call (in MergableCIs) and after the last so the entire region we merge
997     // into a single parallel region is contained in a single basic block
998     // without any other instructions. We use the OpenMPIRBuilder to outline
999     // that block and call the resulting function via __kmpc_fork_call.
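    //
    // As an illustrative (simplified) example, two adjacent calls
    //
    //   call void (...) @__kmpc_fork_call(%ident, i32 0, ... @outlined0, ...)
    //   call void (...) @__kmpc_fork_call(%ident, i32 0, ... @outlined1, ...)
    //
    // are replaced by a single fork of a new outlined function that calls
    // @outlined0 and @outlined1 directly, with an explicit barrier in between
    // that stands in for the implicit fork-join barrier.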
1000     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
      // TODO: Change the interface to allow a single CI to be expanded, e.g.,
      // to include an outer loop.
1003       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1004 
1005       auto Remark = [&](OptimizationRemark OR) {
1006         OR << "Parallel region merged with parallel region"
1007            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1008         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1009           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1010           if (CI != MergableCIs.back())
1011             OR << ", ";
1012         }
1013         return OR << ".";
1014       };
1015 
1016       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1017 
1018       Function *OriginalFn = BB->getParent();
1019       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1020                         << " parallel regions in " << OriginalFn->getName()
1021                         << "\n");
1022 
1023       // Isolate the calls to merge in a separate block.
1024       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1025       BasicBlock *AfterBB =
1026           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1027       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1028                            "omp.par.merged");
1029 
1030       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1031       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1032       BB->getTerminator()->eraseFromParent();
1033 
1034       // Create sequential regions for sequential instructions that are
1035       // in-between mergable parallel regions.
1036       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1037            It != End; ++It) {
1038         Instruction *ForkCI = *It;
1039         Instruction *NextForkCI = *(It + 1);
1040 
        // Continue if there are no in-between instructions.
1042         if (ForkCI->getNextNode() == NextForkCI)
1043           continue;
1044 
1045         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1046                                NextForkCI->getPrevNode());
1047       }
1048 
1049       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1050                                                DL);
1051       IRBuilder<>::InsertPoint AllocaIP(
1052           &OriginalFn->getEntryBlock(),
1053           OriginalFn->getEntryBlock().getFirstInsertionPt());
1054       // Create the merged parallel region with default proc binding, to
1055       // avoid overriding binding settings, and without explicit cancellation.
1056       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1057           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1058           OMP_PROC_BIND_default, /* IsCancellable */ false);
1059       BranchInst::Create(AfterBB, AfterIP.getBlock());
1060 
1061       // Perform the actual outlining.
1062       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1063                                        /* AllowExtractorSinking */ true);
1064 
1065       Function *OutlinedFn = MergableCIs.front()->getCaller();
1066 
1067       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1068       // callbacks.
1069       SmallVector<Value *, 8> Args;
1070       for (auto *CI : MergableCIs) {
1071         Value *Callee =
1072             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1073         FunctionType *FT =
1074             cast<FunctionType>(Callee->getType()->getPointerElementType());
1075         Args.clear();
1076         Args.push_back(OutlinedFn->getArg(0));
1077         Args.push_back(OutlinedFn->getArg(1));
1078         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1079              U < E; ++U)
1080           Args.push_back(CI->getArgOperand(U));
1081 
1082         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1083         if (CI->getDebugLoc())
1084           NewCI->setDebugLoc(CI->getDebugLoc());
1085 
1086         // Forward parameter attributes from the callback to the callee.
1087         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1088              U < E; ++U)
1089           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1090             NewCI->addParamAttr(
1091                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1092 
1093         // Emit an explicit barrier to replace the implicit fork-join barrier.
1094         if (CI != MergableCIs.back()) {
1095           // TODO: Remove barrier if the merged parallel region includes the
1096           // 'nowait' clause.
1097           OMPInfoCache.OMPBuilder.createBarrier(
1098               InsertPointTy(NewCI->getParent(),
1099                             NewCI->getNextNode()->getIterator()),
1100               OMPD_parallel);
1101         }
1102 
1103         CI->eraseFromParent();
1104       }
1105 
1106       assert(OutlinedFn != OriginalFn && "Outlining failed");
1107       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1108       CGUpdater.reanalyzeFunction(*OriginalFn);
1109 
1110       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1111 
1112       return true;
1113     };
1114 
    // Helper function that identifies sequences of
    // __kmpc_fork_call uses in a basic block.
1117     auto DetectPRsCB = [&](Use &U, Function &F) {
1118       CallInst *CI = getCallIfRegularCall(U, &RFI);
1119       BB2PRMap[CI->getParent()].insert(CI);
1120 
1121       return false;
1122     };
1123 
1124     BB2PRMap.clear();
1125     RFI.foreachUse(SCC, DetectPRsCB);
1126     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
    // Find mergable parallel regions within a basic block that are
    // safe to merge, that is, any in-between instructions can safely
    // execute in parallel after merging.
1130     // TODO: support merging across basic-blocks.
1131     for (auto &It : BB2PRMap) {
1132       auto &CIs = It.getSecond();
1133       if (CIs.size() < 2)
1134         continue;
1135 
1136       BasicBlock *BB = It.getFirst();
1137       SmallVector<CallInst *, 4> MergableCIs;
1138 
1139       /// Returns true if the instruction is mergable, false otherwise.
1140       /// A terminator instruction is unmergable by definition since merging
1141       /// works within a BB. Instructions before the mergable region are
1142       /// mergable if they are not calls to OpenMP runtime functions that may
1143       /// set different execution parameters for subsequent parallel regions.
1144       /// Instructions in-between parallel regions are mergable if they are not
1145       /// calls to any non-intrinsic function since that may call a non-mergable
1146       /// OpenMP runtime function.
1147       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1148         // We do not merge across BBs, hence return false (unmergable) if the
1149         // instruction is a terminator.
1150         if (I.isTerminator())
1151           return false;
1152 
1153         if (!isa<CallInst>(&I))
1154           return true;
1155 
1156         CallInst *CI = cast<CallInst>(&I);
1157         if (IsBeforeMergableRegion) {
1158           Function *CalledFunction = CI->getCalledFunction();
1159           if (!CalledFunction)
1160             return false;
1161           // Return false (unmergable) if the call before the parallel
1162           // region calls an explicit affinity (proc_bind) or number of
1163           // threads (num_threads) compiler-generated function. Those settings
1164           // may be incompatible with following parallel regions.
1165           // TODO: ICV tracking to detect compatibility.
1166           for (const auto &RFI : UnmergableCallsInfo) {
1167             if (CalledFunction == RFI.Declaration)
1168               return false;
1169           }
1170         } else {
1171           // Return false (unmergable) if there is a call instruction
1172           // in-between parallel regions when it is not an intrinsic. It
1173           // may call an unmergable OpenMP runtime function in its callpath.
1174           // TODO: Keep track of possible OpenMP calls in the callpath.
1175           if (!isa<IntrinsicInst>(CI))
1176             return false;
1177         }
1178 
1179         return true;
1180       };
1181       // Find maximal number of parallel region CIs that are safe to merge.
1182       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1183         Instruction &I = *It;
1184         ++It;
1185 
1186         if (CIs.count(&I)) {
1187           MergableCIs.push_back(cast<CallInst>(&I));
1188           continue;
1189         }
1190 
1191         // Continue expanding if the instruction is mergable.
1192         if (IsMergable(I, MergableCIs.empty()))
1193           continue;
1194 
1195         // Forward the instruction iterator to skip the next parallel region
1196         // since there is an unmergable instruction which can affect it.
1197         for (; It != End; ++It) {
1198           Instruction &SkipI = *It;
1199           if (CIs.count(&SkipI)) {
1200             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1201                               << " due to " << I << "\n");
1202             ++It;
1203             break;
1204           }
1205         }
1206 
1207         // Store mergable regions found.
1208         if (MergableCIs.size() > 1) {
1209           MergableCIsVector.push_back(MergableCIs);
1210           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1211                             << " parallel regions in block " << BB->getName()
1212                             << " of function " << BB->getParent()->getName()
1213                             << "\n";);
1214         }
1215 
1216         MergableCIs.clear();
1217       }
1218 
1219       if (!MergableCIsVector.empty()) {
1220         Changed = true;
1221 
1222         for (auto &MergableCIs : MergableCIsVector)
1223           Merge(MergableCIs, BB);
1224         MergableCIsVector.clear();
1225       }
1226     }
1227 
1228     if (Changed) {
      /// Re-collect uses for fork calls, emitted barrier calls, and
      /// any emitted master/end_master calls.
1231       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1232       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1233       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1234       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1235     }
1236 
1237     return Changed;
1238   }
1239 
1240   /// Try to delete parallel regions if possible.
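  /// As an illustrative example, a region whose outlined body only reads
  /// memory and is guaranteed to return, e.g., (source-level sketch)
  ///
  ///   #pragma omp parallel
  ///   { int Tmp = A[0]; (void)Tmp; }
  ///
  /// has no observable effect, so its __kmpc_fork_call can be erased.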
1241   bool deleteParallelRegions() {
1242     const unsigned CallbackCalleeOperand = 2;
1243 
1244     OMPInformationCache::RuntimeFunctionInfo &RFI =
1245         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1246 
1247     if (!RFI.Declaration)
1248       return false;
1249 
1250     bool Changed = false;
1251     auto DeleteCallCB = [&](Use &U, Function &) {
1252       CallInst *CI = getCallIfRegularCall(U);
1253       if (!CI)
1254         return false;
1255       auto *Fn = dyn_cast<Function>(
1256           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1257       if (!Fn)
1258         return false;
1259       if (!Fn->onlyReadsMemory())
1260         return false;
1261       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1262         return false;
1263 
1264       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1265                         << CI->getCaller()->getName() << "\n");
1266 
1267       auto Remark = [&](OptimizationRemark OR) {
1268         return OR << "Removing parallel region with no side-effects.";
1269       };
1270       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1271 
1272       CGUpdater.removeCallSite(*CI);
1273       CI->eraseFromParent();
1274       Changed = true;
1275       ++NumOpenMPParallelRegionsDeleted;
1276       return true;
1277     };
1278 
1279     RFI.foreachUse(SCC, DeleteCallCB);
1280 
1281     return Changed;
1282   }
1283 
1284   /// Try to eliminate runtime calls by reusing existing ones.
1285   bool deduplicateRuntimeCalls() {
1286     bool Changed = false;
1287 
1288     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1289         OMPRTL_omp_get_num_threads,
1290         OMPRTL_omp_in_parallel,
1291         OMPRTL_omp_get_cancellation,
1292         OMPRTL_omp_get_thread_limit,
1293         OMPRTL_omp_get_supported_active_levels,
1294         OMPRTL_omp_get_level,
1295         OMPRTL_omp_get_ancestor_thread_num,
1296         OMPRTL_omp_get_team_size,
1297         OMPRTL_omp_get_active_level,
1298         OMPRTL_omp_in_final,
1299         OMPRTL_omp_get_proc_bind,
1300         OMPRTL_omp_get_num_places,
1301         OMPRTL_omp_get_num_procs,
1302         OMPRTL_omp_get_place_num,
1303         OMPRTL_omp_get_partition_num_places,
1304         OMPRTL_omp_get_partition_place_nums};
1305 
1306     // Global-tid is handled separately.
1307     SmallSetVector<Value *, 16> GTIdArgs;
1308     collectGlobalThreadIdArguments(GTIdArgs);
1309     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1310                       << " global thread ID arguments\n");
1311 
1312     for (Function *F : SCC) {
1313       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1314         Changed |= deduplicateRuntimeCalls(
1315             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1316 
1317       // __kmpc_global_thread_num is special as we can replace it with an
1318       // argument in enough cases to make it worth trying.
1319       Value *GTIdArg = nullptr;
1320       for (Argument &Arg : F->args())
1321         if (GTIdArgs.count(&Arg)) {
1322           GTIdArg = &Arg;
1323           break;
1324         }
1325       Changed |= deduplicateRuntimeCalls(
1326           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1327     }
1328 
1329     return Changed;
1330   }
1331 
  /// Tries to hide the latency of runtime calls that involve host to
  /// device memory transfers by splitting them into their "issue" and "wait"
  /// versions. The "issue" is moved upwards as much as possible. The "wait" is
  /// moved downwards as much as possible. The "issue" issues the memory
  /// transfer asynchronously, returning a handle. The "wait" waits on the
  /// returned handle for the memory transfer to finish.
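  ///
  /// For example (illustrative, simplified IR), a call like
  ///   call void @__tgt_target_data_begin_mapper(...)
  /// becomes, roughly,
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   ; side-effect free code the transfer can overlap with
  ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)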
1338   bool hideMemTransfersLatency() {
1339     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1340     bool Changed = false;
1341     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1342       auto *RTCall = getCallIfRegularCall(U, &RFI);
1343       if (!RTCall)
1344         return false;
1345 
1346       OffloadArray OffloadArrays[3];
1347       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1348         return false;
1349 
1350       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1351 
      // TODO: Check if it can be moved upwards.
1353       bool WasSplit = false;
1354       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1355       if (WaitMovementPoint)
1356         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1357 
1358       Changed |= WasSplit;
1359       return WasSplit;
1360     };
1361     RFI.foreachUse(SCC, SplitMemTransfers);
1362 
1363     return Changed;
1364   }
1365 
1366   void analysisGlobalization() {
1367     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1368 
1369     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1370       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1371         auto Remark = [&](OptimizationRemarkMissed ORM) {
1372           return ORM
1373                  << "Found thread data sharing on the GPU. "
1374                  << "Expect degraded performance due to data globalization.";
1375         };
1376         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1377       }
1378 
1379       return false;
1380     };
1381 
1382     RFI.foreachUse(SCC, CheckGlobalization);
1383   }
1384 
1385   /// Maps the values stored in the offload arrays passed as arguments to
1386   /// \p RuntimeCall into the offload arrays in \p OAs.
1387   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1388                                 MutableArrayRef<OffloadArray> OAs) {
1389     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1390 
1391     // A runtime call that involves memory offloading looks something like:
1392     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1393     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1394     // ...)
1395     // So, the idea is to access the allocas that allocate space for these
1396     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1397     // Therefore:
1398     // i8** %offload_baseptrs.
1399     Value *BasePtrsArg =
1400         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1401     // i8** %offload_ptrs.
1402     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1403     // i8** %offload_sizes.
1404     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1405 
1406     // Get values stored in **offload_baseptrs.
1407     auto *V = getUnderlyingObject(BasePtrsArg);
1408     if (!isa<AllocaInst>(V))
1409       return false;
1410     auto *BasePtrsArray = cast<AllocaInst>(V);
1411     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1412       return false;
1413 
    // Get values stored in **offload_ptrs.
1415     V = getUnderlyingObject(PtrsArg);
1416     if (!isa<AllocaInst>(V))
1417       return false;
1418     auto *PtrsArray = cast<AllocaInst>(V);
1419     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1420       return false;
1421 
1422     // Get values stored in **offload_sizes.
1423     V = getUnderlyingObject(SizesArg);
1424     // If it's a [constant] global array don't analyze it.
1425     if (isa<GlobalValue>(V))
1426       return isa<Constant>(V);
1427     if (!isa<AllocaInst>(V))
1428       return false;
1429 
1430     auto *SizesArray = cast<AllocaInst>(V);
1431     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1432       return false;
1433 
1434     return true;
1435   }
1436 
1437   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1438   /// For now this is a way to test that the function getValuesInOffloadArrays
1439   /// is working properly.
1440   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1441   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1442     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1443 
1444     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1445     std::string ValuesStr;
1446     raw_string_ostream Printer(ValuesStr);
1447     std::string Separator = " --- ";
1448 
1449     for (auto *BP : OAs[0].StoredValues) {
1450       BP->print(Printer);
1451       Printer << Separator;
1452     }
1453     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1454     ValuesStr.clear();
1455 
1456     for (auto *P : OAs[1].StoredValues) {
1457       P->print(Printer);
1458       Printer << Separator;
1459     }
1460     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1461     ValuesStr.clear();
1462 
1463     for (auto *S : OAs[2].StoredValues) {
1464       S->print(Printer);
1465       Printer << Separator;
1466     }
1467     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1468   }
1469 
  /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
  /// can be moved. Returns nullptr if the movement is not possible, or not
  /// worth it.
1472   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1473     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1474     //  Make it traverse the CFG.
1475 
1476     Instruction *CurrentI = &RuntimeCall;
1477     bool IsWorthIt = false;
1478     while ((CurrentI = CurrentI->getNextNode())) {
1479 
1480       // TODO: Once we detect the regions to be offloaded we should use the
1481       //  alias analysis manager to check if CurrentI may modify one of
1482       //  the offloaded regions.
1483       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1484         if (IsWorthIt)
1485           return CurrentI;
1486 
1487         return nullptr;
1488       }
1489 
      // FIXME: For now, moving the call over anything without side effects
      //  is considered worth it.
1492       IsWorthIt = true;
1493     }
1494 
1495     // Return end of BasicBlock.
1496     return RuntimeCall.getParent()->getTerminator();
1497   }
1498 
1499   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1500   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1501                                Instruction &WaitMovementPoint) {
    // Create a stack-allocated handle (__tgt_async_info) at the beginning of
    // the function. It is used to store information about the asynchronous
    // transfer so we can wait on it later.
1505     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1506     auto *F = RuntimeCall.getCaller();
1507     Instruction *FirstInst = &(F->getEntryBlock().front());
1508     AllocaInst *Handle = new AllocaInst(
1509         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1510 
1511     // Add "issue" runtime call declaration:
1512     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1513     //   i8**, i8**, i64*, i64*)
1514     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1515         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1516 
1517     // Change RuntimeCall call site for its asynchronous version.
1518     SmallVector<Value *, 16> Args;
1519     for (auto &Arg : RuntimeCall.args())
1520       Args.push_back(Arg.get());
1521     Args.push_back(Handle);
1522 
1523     CallInst *IssueCallsite =
1524         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1525     RuntimeCall.eraseFromParent();
1526 
1527     // Add "wait" runtime call declaration:
1528     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1529     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1530         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1531 
1532     Value *WaitParams[2] = {
1533         IssueCallsite->getArgOperand(
1534             OffloadArray::DeviceIDArgNum), // device_id.
1535         Handle                             // handle to wait on.
1536     };
1537     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1538 
1539     return true;
1540   }
1541 
1542   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1543                                     bool GlobalOnly, bool &SingleChoice) {
1544     if (CurrentIdent == NextIdent)
1545       return CurrentIdent;
1546 
1547     // TODO: Figure out how to actually combine multiple debug locations. For
1548     //       now we just keep an existing one if there is a single choice.
1549     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
1550       SingleChoice = !CurrentIdent;
1551       return NextIdent;
1552     }
1553     return nullptr;
1554   }
1555 
  /// Return a `struct ident_t*` value that represents the ones used in the
1557   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1558   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1559   /// return value we create one from scratch. We also do not yet combine
1560   /// information, e.g., the source locations, see combinedIdentStruct.
1561   Value *
1562   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1563                                  Function &F, bool GlobalOnly) {
1564     bool SingleChoice = true;
1565     Value *Ident = nullptr;
1566     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1567       CallInst *CI = getCallIfRegularCall(U, &RFI);
1568       if (!CI || &F != &Caller)
1569         return false;
1570       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1571                                   /* GlobalOnly */ true, SingleChoice);
1572       return false;
1573     };
1574     RFI.foreachUse(SCC, CombineIdentStruct);
1575 
1576     if (!Ident || !SingleChoice) {
      // The IRBuilder uses the insertion block to get to the module; this is
      // unfortunate but we work around it for now.
1579       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1580         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1581             &F.getEntryBlock(), F.getEntryBlock().begin()));
      // Create a fallback location if none was found.
1583       // TODO: Use the debug locations of the calls instead.
1584       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1585       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1586     }
1587     return Ident;
1588   }
1589 
1590   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1591   /// \p ReplVal if given.
1592   bool deduplicateRuntimeCalls(Function &F,
1593                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1594                                Value *ReplVal = nullptr) {
1595     auto *UV = RFI.getUseVector(F);
1596     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1597       return false;
1598 
    LLVM_DEBUG(
        dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
               << (ReplVal ? " with an existing value\n" : "\n"));
1602 
1603     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1604                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1605            "Unexpected replacement value!");
1606 
1607     // TODO: Use dominance to find a good position instead.
1608     auto CanBeMoved = [this](CallBase &CB) {
1609       unsigned NumArgs = CB.getNumArgOperands();
1610       if (NumArgs == 0)
1611         return true;
1612       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1613         return false;
1614       for (unsigned u = 1; u < NumArgs; ++u)
1615         if (isa<Instruction>(CB.getArgOperand(u)))
1616           return false;
1617       return true;
1618     };
1619 
1620     if (!ReplVal) {
1621       for (Use *U : *UV)
1622         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1623           if (!CanBeMoved(*CI))
1624             continue;
1625 
1626           // If the function is a kernel, dedup will move
1627           // the runtime call right after the kernel init callsite. Otherwise,
1628           // it will move it to the beginning of the caller function.
1629           if (isKernel(F)) {
1630             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1631             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1632 
1633             if (KernelInitUV->empty())
1634               continue;
1635 
1636             assert(KernelInitUV->size() == 1 &&
1637                    "Expected a single __kmpc_target_init in kernel\n");
1638 
1639             CallInst *KernelInitCI =
1640                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1641             assert(KernelInitCI &&
1642                    "Expected a call to __kmpc_target_init in kernel\n");
1643 
1644             CI->moveAfter(KernelInitCI);
1645           } else
1646             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1647           ReplVal = CI;
1648           break;
1649         }
1650       if (!ReplVal)
1651         return false;
1652     }
1653 
1654     // If we use a call as a replacement value we need to make sure the ident is
1655     // valid at the new location. For now we just pick a global one, either
1656     // existing and used by one of the calls, or created from scratch.
1657     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1658       if (!CI->arg_empty() &&
1659           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1660         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1661                                                       /* GlobalOnly */ true);
1662         CI->setArgOperand(0, Ident);
1663       }
1664     }
1665 
1666     bool Changed = false;
1667     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1668       CallInst *CI = getCallIfRegularCall(U, &RFI);
1669       if (!CI || CI == ReplVal || &F != &Caller)
1670         return false;
1671       assert(CI->getCaller() == &F && "Unexpected call!");
1672 
1673       auto Remark = [&](OptimizationRemark OR) {
1674         return OR << "OpenMP runtime call "
1675                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1676       };
1677       if (CI->getDebugLoc())
1678         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1679       else
1680         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1681 
1682       CGUpdater.removeCallSite(*CI);
1683       CI->replaceAllUsesWith(ReplVal);
1684       CI->eraseFromParent();
1685       ++NumOpenMPRuntimeCallsDeduplicated;
1686       Changed = true;
1687       return true;
1688     };
1689     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1690 
1691     return Changed;
1692   }
1693 
1694   /// Collect arguments that represent the global thread id in \p GTIdArgs.
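  /// A value is considered a GTId argument if, e.g. (illustrative):
  ///   %gtid = call i32 @__kmpc_global_thread_num(%struct.ident_t* @loc)
  ///   call void @helper(i32 %gtid)
  /// and every call site of @helper passes a known GTId in that position.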
1695   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1696     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1697     //       initialization. We could define an AbstractAttribute instead and
1698     //       run the Attributor here once it can be run as an SCC pass.
1699 
1700     // Helper to check the argument \p ArgNo at all call sites of \p F for
1701     // a GTId.
1702     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1703       if (!F.hasLocalLinkage())
1704         return false;
1705       for (Use &U : F.uses()) {
1706         if (CallInst *CI = getCallIfRegularCall(U)) {
1707           Value *ArgOp = CI->getArgOperand(ArgNo);
1708           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1709               getCallIfRegularCall(
1710                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1711             continue;
1712         }
1713         return false;
1714       }
1715       return true;
1716     };
1717 
1718     // Helper to identify uses of a GTId as GTId arguments.
1719     auto AddUserArgs = [&](Value &GTId) {
1720       for (Use &U : GTId.uses())
1721         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1722           if (CI->isArgOperand(&U))
1723             if (Function *Callee = CI->getCalledFunction())
1724               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1725                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1726     };
1727 
1728     // The argument users of __kmpc_global_thread_num calls are GTIds.
1729     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1730         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1731 
1732     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1733       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1734         AddUserArgs(*CI);
1735       return false;
1736     });
1737 
1738     // Transitively search for more arguments by looking at the users of the
1739     // ones we know already. During the search the GTIdArgs vector is extended
1740     // so we cannot cache the size nor can we use a range based for.
1741     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1742       AddUserArgs(*GTIdArgs[u]);
1743   }
1744 
1745   /// Kernel (=GPU) optimizations and utility functions
1746   ///
1747   ///{{
1748 
1749   /// Check if \p F is a kernel, hence entry point for target offloading.
1750   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1751 
1752   /// Cache to remember the unique kernel for a function.
1753   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1754 
1755   /// Find the unique kernel that will execute \p F, if any.
1756   Kernel getUniqueKernelFor(Function &F);
1757 
1758   /// Find the unique kernel that will execute \p I, if any.
1759   Kernel getUniqueKernelFor(Instruction &I) {
1760     return getUniqueKernelFor(*I.getFunction());
1761   }
1762 
  /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
  /// the cases where we can avoid taking the address of a function.
1765   bool rewriteDeviceCodeStateMachine();
1766 
1767   ///
1768   ///}}
1769 
1770   /// Emit a remark generically
1771   ///
1772   /// This template function can be used to generically emit a remark. The
1773   /// RemarkKind should be one of the following:
1774   ///   - OptimizationRemark to indicate a successful optimization attempt
1775   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1776   ///   - OptimizationRemarkAnalysis to provide additional information about an
1777   ///     optimization attempt
1778   ///
1779   /// The remark is built using a callback function provided by the caller that
1780   /// takes a RemarkKind as input and returns a RemarkKind.
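  ///
  /// A typical use looks like (illustrative):
  ///   auto Remark = [&](OptimizationRemark OR) {
  ///     return OR << "Removed redundant runtime call.";
  ///   };
  ///   emitRemark<OptimizationRemark>(CI, "OMP170", Remark);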
1781   template <typename RemarkKind, typename RemarkCallBack>
1782   void emitRemark(Instruction *I, StringRef RemarkName,
1783                   RemarkCallBack &&RemarkCB) const {
1784     Function *F = I->getParent()->getParent();
1785     auto &ORE = OREGetter(F);
1786 
1787     if (RemarkName.startswith("OMP"))
1788       ORE.emit([&]() {
1789         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1790                << " [" << RemarkName << "]";
1791       });
1792     else
1793       ORE.emit(
1794           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1795   }
1796 
1797   /// Emit a remark on a function.
1798   template <typename RemarkKind, typename RemarkCallBack>
1799   void emitRemark(Function *F, StringRef RemarkName,
1800                   RemarkCallBack &&RemarkCB) const {
1801     auto &ORE = OREGetter(F);
1802 
1803     if (RemarkName.startswith("OMP"))
1804       ORE.emit([&]() {
1805         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1806                << " [" << RemarkName << "]";
1807       });
1808     else
1809       ORE.emit(
1810           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1811   }
1812 
1813   /// RAII struct to temporarily change an RTL function's linkage to external.
1814   /// This prevents it from being mistakenly removed by other optimizations.
1815   struct ExternalizationRAII {
1816     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1817                         RuntimeFunction RFKind)
1818         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1819       if (!Declaration)
1820         return;
1821 
1822       LinkageType = Declaration->getLinkage();
1823       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1824     }
1825 
1826     ~ExternalizationRAII() {
1827       if (!Declaration)
1828         return;
1829 
1830       Declaration->setLinkage(LinkageType);
1831     }
1832 
1833     Function *Declaration;
1834     GlobalValue::LinkageTypes LinkageType;
1835   };
1836 
1837   /// The underlying module.
1838   Module &M;
1839 
1840   /// The SCC we are operating on.
1841   SmallVectorImpl<Function *> &SCC;
1842 
  /// Callback to update the call graph; the first argument is a removed call,
  /// the second an optional replacement call.
1845   CallGraphUpdater &CGUpdater;
1846 
1847   /// Callback to get an OptimizationRemarkEmitter from a Function *
1848   OptimizationRemarkGetter OREGetter;
1849 
  /// OpenMP-specific information cache. Also used for Attributor runs.
1851   OMPInformationCache &OMPInfoCache;
1852 
1853   /// Attributor instance.
1854   Attributor &A;
1855 
1856   /// Helper function to run Attributor on SCC.
1857   bool runAttributor(bool IsModulePass) {
1858     if (SCC.empty())
1859       return false;
1860 
    // Temporarily make these functions have external linkage so the Attributor
    // doesn't remove them when we try to look them up later.
1863     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1864     ExternalizationRAII EndParallel(OMPInfoCache,
1865                                     OMPRTL___kmpc_kernel_end_parallel);
1866     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1867                                     OMPRTL___kmpc_barrier_simple_spmd);
1868     ExternalizationRAII ThreadId(OMPInfoCache,
1869                                  OMPRTL___kmpc_get_hardware_thread_id_in_block);
1870 
1871     registerAAs(IsModulePass);
1872 
1873     ChangeStatus Changed = A.run();
1874 
1875     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1876                       << " functions, result: " << Changed << ".\n");
1877 
1878     return Changed == ChangeStatus::CHANGED;
1879   }
1880 
1881   void registerFoldRuntimeCall(RuntimeFunction RF);
1882 
1883   /// Populate the Attributor with abstract attribute opportunities in the
1884   /// function.
1885   void registerAAs(bool IsModulePass);
1886 };
1887 
1888 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1889   if (!OMPInfoCache.ModuleSlice.count(&F))
1890     return nullptr;
1891 
1892   // Use a scope to keep the lifetime of the CachedKernel short.
1893   {
1894     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1895     if (CachedKernel)
1896       return *CachedKernel;
1897 
1898     // TODO: We should use an AA to create an (optimistic and callback
1899     //       call-aware) call graph. For now we stick to simple patterns that
1900     //       are less powerful, basically the worst fixpoint.
1901     if (isKernel(F)) {
1902       CachedKernel = Kernel(&F);
1903       return *CachedKernel;
1904     }
1905 
1906     CachedKernel = nullptr;
1907     if (!F.hasLocalLinkage()) {
1908 
1909       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1910       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1911         return ORA << "Potentially unknown OpenMP target region caller.";
1912       };
1913       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1914 
1915       return nullptr;
1916     }
1917   }
1918 
1919   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1920     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1921       // Allow use in equality comparisons.
1922       if (Cmp->isEquality())
1923         return getUniqueKernelFor(*Cmp);
1924       return nullptr;
1925     }
1926     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1927       // Allow direct calls.
1928       if (CB->isCallee(&U))
1929         return getUniqueKernelFor(*CB);
1930 
1931       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1932           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1933       // Allow the use in __kmpc_parallel_51 calls.
1934       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1935         return getUniqueKernelFor(*CB);
1936       return nullptr;
1937     }
1938     // Disallow every other use.
1939     return nullptr;
1940   };
1941 
1942   // TODO: In the future we want to track more than just a unique kernel.
1943   SmallPtrSet<Kernel, 2> PotentialKernels;
1944   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1945     PotentialKernels.insert(GetUniqueKernelForUse(U));
1946   });
1947 
1948   Kernel K = nullptr;
1949   if (PotentialKernels.size() == 1)
1950     K = *PotentialKernels.begin();
1951 
1952   // Cache the result.
1953   UniqueKernelMap[&F] = K;
1954 
1955   return K;
1956 }
1957 
1958 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1959   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1960       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1961 
1962   bool Changed = false;
1963   if (!KernelParallelRFI)
1964     return Changed;
1965 
1966   // If we have disabled state machine changes, exit
1967   if (DisableOpenMPOptStateMachineRewrite)
1968     return Changed;
1969 
1970   for (Function *F : SCC) {
1971 
    // Check if the function is used in a __kmpc_parallel_51 call at all.
1974     bool UnknownUse = false;
1975     bool KernelParallelUse = false;
1976     unsigned NumDirectCalls = 0;
1977 
1978     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1979     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1980       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1981         if (CB->isCallee(&U)) {
1982           ++NumDirectCalls;
1983           return;
1984         }
1985 
1986       if (isa<ICmpInst>(U.getUser())) {
1987         ToBeReplacedStateMachineUses.push_back(&U);
1988         return;
1989       }
1990 
1991       // Find wrapper functions that represent parallel kernels.
1992       CallInst *CI =
1993           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1994       const unsigned int WrapperFunctionArgNo = 6;
1995       if (!KernelParallelUse && CI &&
1996           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1997         KernelParallelUse = true;
1998         ToBeReplacedStateMachineUses.push_back(&U);
1999         return;
2000       }
2001       UnknownUse = true;
2002     });
2003 
2004     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2005     // use.
2006     if (!KernelParallelUse)
2007       continue;
2008 
2009     // If this ever hits, we should investigate.
2010     // TODO: Checking the number of uses is not a necessary restriction and
2011     // should be lifted.
2012     if (UnknownUse || NumDirectCalls != 1 ||
2013         ToBeReplacedStateMachineUses.size() > 2) {
2014       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2015         return ORA << "Parallel region is used in "
2016                    << (UnknownUse ? "unknown" : "unexpected")
2017                    << " ways. Will not attempt to rewrite the state machine.";
2018       };
2019       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2020       continue;
2021     }
2022 
2023     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2024     // up if the function is not called from a unique kernel.
2025     Kernel K = getUniqueKernelFor(*F);
2026     if (!K) {
2027       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2028         return ORA << "Parallel region is not called from a unique kernel. "
2029                       "Will not attempt to rewrite the state machine.";
2030       };
2031       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2032       continue;
2033     }
2034 
2035     // We now know F is a parallel body function called only from the kernel K.
2036     // We also identified the state machine uses in which we replace the
2037     // function pointer by a new global symbol for identification purposes. This
2038     // ensures only direct calls to the function are left.
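    // The result looks roughly like (illustrative, names made up):
    //   @__omp_outlined__wrapper.ID = private constant i8 undef
    //   ...
    //   call void @__kmpc_parallel_51(..., i8* @__omp_outlined__wrapper.ID, ...)
    //   ...
    //   %match = icmp eq i8* %work_fn, @__omp_outlined__wrapper.ID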
2039 
2040     Module &M = *F->getParent();
2041     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2042 
2043     auto *ID = new GlobalVariable(
2044         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2045         UndefValue::get(Int8Ty), F->getName() + ".ID");
2046 
2047     for (Use *U : ToBeReplacedStateMachineUses)
2048       U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2049           ID, U->get()->getType()));
2050 
2051     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2052 
2053     Changed = true;
2054   }
2055 
2056   return Changed;
2057 }
2058 
2059 /// Abstract Attribute for tracking ICV values.
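/// For example (illustrative), with the nthreads ICV tracked, in
///   call void @omp_set_num_threads(i32 4)
///   %n = call i32 @omp_get_max_threads()
/// the getter call can be folded to 4, assuming no intervening call may have
/// changed the ICV.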
2060 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2061   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2062   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2063 
2064   void initialize(Attributor &A) override {
2065     Function *F = getAnchorScope();
2066     if (!F || !A.isFunctionIPOAmendable(*F))
2067       indicatePessimisticFixpoint();
2068   }
2069 
2070   /// Returns true if value is assumed to be tracked.
2071   bool isAssumedTracked() const { return getAssumed(); }
2072 
2073   /// Returns true if value is known to be tracked.
2074   bool isKnownTracked() const { return getAssumed(); }
2075 
  /// Create an abstract attribute view for the position \p IRP.
2077   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2078 
  /// Return the value with which \p I can be replaced for the specific \p ICV.
2080   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2081                                                 const Instruction *I,
2082                                                 Attributor &A) const {
2083     return None;
2084   }
2085 
  /// Return an assumed unique ICV value if a single candidate is found. If
  /// there cannot be one, return nullptr. If it is not clear yet, return
  /// None.
2089   virtual Optional<Value *>
2090   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2091 
  // Currently only nthreads is being tracked.
  // This array will only grow over time.
2094   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2095 
2096   /// See AbstractAttribute::getName()
2097   const std::string getName() const override { return "AAICVTracker"; }
2098 
2099   /// See AbstractAttribute::getIdAddr()
2100   const char *getIdAddr() const override { return &ID; }
2101 
2102   /// This function should return true if the type of the \p AA is AAICVTracker
2103   static bool classof(const AbstractAttribute *AA) {
2104     return (AA->getIdAddr() == &ID);
2105   }
2106 
2107   static const char ID;
2108 };
2109 
2110 struct AAICVTrackerFunction : public AAICVTracker {
2111   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2112       : AAICVTracker(IRP, A) {}
2113 
2114   // FIXME: come up with better string.
2115   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2116 
2117   // FIXME: come up with some stats.
2118   void trackStatistics() const override {}
2119 
2120   /// We don't manifest anything for this AA.
2121   ChangeStatus manifest(Attributor &A) override {
2122     return ChangeStatus::UNCHANGED;
2123   }
2124 
  // Map of ICV to their values at specific program points.
2126   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2127                   InternalControlVar::ICV___last>
2128       ICVReplacementValuesMap;
2129 
2130   ChangeStatus updateImpl(Attributor &A) override {
2131     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2132 
2133     Function *F = getAnchorScope();
2134 
2135     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2136 
2137     for (InternalControlVar ICV : TrackableICVs) {
2138       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2139 
2140       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2141       auto TrackValues = [&](Use &U, Function &) {
2142         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2143         if (!CI)
2144           return false;
2145 
        // FIXME: Handle setters with more than one argument.
        // Track the new value.
2148         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2149           HasChanged = ChangeStatus::CHANGED;
2150 
2151         return false;
2152       };
2153 
2154       auto CallCheck = [&](Instruction &I) {
2155         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2156         if (ReplVal.hasValue() &&
2157             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2158           HasChanged = ChangeStatus::CHANGED;
2159 
2160         return true;
2161       };
2162 
2163       // Track all changes of an ICV.
2164       SetterRFI.foreachUse(TrackValues, F);
2165 
2166       bool UsedAssumedInformation = false;
2167       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2168                                 UsedAssumedInformation,
2169                                 /* CheckBBLivenessOnly */ true);
2170 
      // TODO: Figure out a way to avoid adding an entry in
      // ICVReplacementValuesMap.
2173       Instruction *Entry = &F->getEntryBlock().front();
2174       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2175         ValuesMap.insert(std::make_pair(Entry, nullptr));
2176     }
2177 
2178     return HasChanged;
2179   }
2180 
  /// Helper to check if \p I is a call and get the value for it if it is
2182   /// unique.
2183   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2184                                     InternalControlVar &ICV) const {
2185 
2186     const auto *CB = dyn_cast<CallBase>(I);
2187     if (!CB || CB->hasFnAttr("no_openmp") ||
2188         CB->hasFnAttr("no_openmp_routines"))
2189       return None;
2190 
2191     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2192     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2193     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2194     Function *CalledFunction = CB->getCalledFunction();
2195 
2196     // Indirect call, assume ICV changes.
2197     if (CalledFunction == nullptr)
2198       return nullptr;
2199     if (CalledFunction == GetterRFI.Declaration)
2200       return None;
2201     if (CalledFunction == SetterRFI.Declaration) {
2202       if (ICVReplacementValuesMap[ICV].count(I))
2203         return ICVReplacementValuesMap[ICV].lookup(I);
2204 
2205       return nullptr;
2206     }
2207 
2208     // Since we don't know, assume it changes the ICV.
2209     if (CalledFunction->isDeclaration())
2210       return nullptr;
2211 
2212     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2213         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2214 
2215     if (ICVTrackingAA.isAssumedTracked())
2216       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2217 
2218     // If we don't know, assume it changes.
2219     return nullptr;
2220   }
2221 
  // We don't check for a unique value for a function, so we return None.
2223   Optional<Value *>
2224   getUniqueReplacementValue(InternalControlVar ICV) const override {
2225     return None;
2226   }
2227 
  /// Return the value with which \p I can be replaced for the specific \p ICV.
2229   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2230                                         const Instruction *I,
2231                                         Attributor &A) const override {
2232     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2233     if (ValuesMap.count(I))
2234       return ValuesMap.lookup(I);
2235 
2236     SmallVector<const Instruction *, 16> Worklist;
2237     SmallPtrSet<const Instruction *, 16> Visited;
2238     Worklist.push_back(I);
2239 
2240     Optional<Value *> ReplVal;
2241 
2242     while (!Worklist.empty()) {
2243       const Instruction *CurrInst = Worklist.pop_back_val();
2244       if (!Visited.insert(CurrInst).second)
2245         continue;
2246 
2247       const BasicBlock *CurrBB = CurrInst->getParent();
2248 
2249       // Go up and look for all potential setters/calls that might change the
2250       // ICV.
2251       while ((CurrInst = CurrInst->getPrevNode())) {
2252         if (ValuesMap.count(CurrInst)) {
2253           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2254           // Unknown value, track new.
2255           if (!ReplVal.hasValue()) {
2256             ReplVal = NewReplVal;
2257             break;
2258           }
2259 
          // If we found a new value, we can't know the ICV value anymore.
2261           if (NewReplVal.hasValue())
2262             if (ReplVal != NewReplVal)
2263               return nullptr;
2264 
2265           break;
2266         }
2267 
2268         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2269         if (!NewReplVal.hasValue())
2270           continue;
2271 
2272         // Unknown value, track new.
2273         if (!ReplVal.hasValue()) {
2274           ReplVal = NewReplVal;
2275           break;
2276         }
2277 
        // We found a new value, so we can't know the ICV value anymore.
        if (ReplVal != NewReplVal)
2281           return nullptr;
2282       }
2283 
2284       // If we are in the same BB and we have a value, we are done.
2285       if (CurrBB == I->getParent() && ReplVal.hasValue())
2286         return ReplVal;
2287 
2288       // Go through all predecessors and add terminators for analysis.
2289       for (const BasicBlock *Pred : predecessors(CurrBB))
2290         if (const Instruction *Terminator = Pred->getTerminator())
2291           Worklist.push_back(Terminator);
2292     }
2293 
2294     return ReplVal;
2295   }
2296 };
2297 
2298 struct AAICVTrackerFunctionReturned : AAICVTracker {
2299   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2300       : AAICVTracker(IRP, A) {}
2301 
2302   // FIXME: come up with better string.
2303   const std::string getAsStr() const override {
2304     return "ICVTrackerFunctionReturned";
2305   }
2306 
2307   // FIXME: come up with some stats.
2308   void trackStatistics() const override {}
2309 
2310   /// We don't manifest anything for this AA.
2311   ChangeStatus manifest(Attributor &A) override {
2312     return ChangeStatus::UNCHANGED;
2313   }
2314 
  // Map of ICV to their values at specific program points.
2316   EnumeratedArray<Optional<Value *>, InternalControlVar,
2317                   InternalControlVar::ICV___last>
2318       ICVReplacementValuesMap;
2319 
  /// Return the unique replacement value for \p ICV, if any was found.
2321   Optional<Value *>
2322   getUniqueReplacementValue(InternalControlVar ICV) const override {
2323     return ICVReplacementValuesMap[ICV];
2324   }
2325 
2326   ChangeStatus updateImpl(Attributor &A) override {
2327     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2328     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2329         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2330 
2331     if (!ICVTrackingAA.isAssumedTracked())
2332       return indicatePessimisticFixpoint();
2333 
2334     for (InternalControlVar ICV : TrackableICVs) {
2335       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2336       Optional<Value *> UniqueICVValue;
2337 
2338       auto CheckReturnInst = [&](Instruction &I) {
2339         Optional<Value *> NewReplVal =
2340             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2341 
2342         // If we found a second ICV value there is no unique returned value.
2343         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2344           return false;
2345 
2346         UniqueICVValue = NewReplVal;
2347 
2348         return true;
2349       };
2350 
2351       bool UsedAssumedInformation = false;
2352       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2353                                      UsedAssumedInformation,
2354                                      /* CheckBBLivenessOnly */ true))
2355         UniqueICVValue = nullptr;
2356 
2357       if (UniqueICVValue == ReplVal)
2358         continue;
2359 
2360       ReplVal = UniqueICVValue;
2361       Changed = ChangeStatus::CHANGED;
2362     }
2363 
2364     return Changed;
2365   }
2366 };
2367 
2368 struct AAICVTrackerCallSite : AAICVTracker {
2369   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2370       : AAICVTracker(IRP, A) {}
2371 
2372   void initialize(Attributor &A) override {
2373     Function *F = getAnchorScope();
2374     if (!F || !A.isFunctionIPOAmendable(*F))
2375       indicatePessimisticFixpoint();
2376 
2377     // We only initialize this AA for getters, so we need to know which ICV it
2378     // gets.
2379     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2380     for (InternalControlVar ICV : TrackableICVs) {
2381       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2382       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2383       if (Getter.Declaration == getAssociatedFunction()) {
2384         AssociatedICV = ICVInfo.Kind;
2385         return;
2386       }
2387     }
2388 
2389     /// Unknown ICV.
2390     indicatePessimisticFixpoint();
2391   }
2392 
2393   ChangeStatus manifest(Attributor &A) override {
2394     if (!ReplVal.hasValue() || !ReplVal.getValue())
2395       return ChangeStatus::UNCHANGED;
2396 
2397     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2398     A.deleteAfterManifest(*getCtxI());
2399 
2400     return ChangeStatus::CHANGED;
2401   }
2402 
2403   // FIXME: come up with better string.
2404   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2405 
2406   // FIXME: come up with some stats.
2407   void trackStatistics() const override {}
2408 
2409   InternalControlVar AssociatedICV;
2410   Optional<Value *> ReplVal;
2411 
2412   ChangeStatus updateImpl(Attributor &A) override {
2413     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2414         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2415 
2416     // We don't have any information, so we assume it changes the ICV.
2417     if (!ICVTrackingAA.isAssumedTracked())
2418       return indicatePessimisticFixpoint();
2419 
2420     Optional<Value *> NewReplVal =
2421         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2422 
2423     if (ReplVal == NewReplVal)
2424       return ChangeStatus::UNCHANGED;
2425 
2426     ReplVal = NewReplVal;
2427     return ChangeStatus::CHANGED;
2428   }
2429 
  // Return the value with which the associated value can be replaced for the
  // specific \p ICV.
2432   Optional<Value *>
2433   getUniqueReplacementValue(InternalControlVar ICV) const override {
2434     return ReplVal;
2435   }
2436 };
2437 
2438 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2439   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2440       : AAICVTracker(IRP, A) {}
2441 
2442   // FIXME: come up with better string.
2443   const std::string getAsStr() const override {
2444     return "ICVTrackerCallSiteReturned";
2445   }
2446 
2447   // FIXME: come up with some stats.
2448   void trackStatistics() const override {}
2449 
2450   /// We don't manifest anything for this AA.
2451   ChangeStatus manifest(Attributor &A) override {
2452     return ChangeStatus::UNCHANGED;
2453   }
2454 
  // Map of ICV to their values at specific program points.
2456   EnumeratedArray<Optional<Value *>, InternalControlVar,
2457                   InternalControlVar::ICV___last>
2458       ICVReplacementValuesMap;
2459 
  /// Return the value with which the associated value can be replaced for the
  /// specific \p ICV.
2462   Optional<Value *>
2463   getUniqueReplacementValue(InternalControlVar ICV) const override {
2464     return ICVReplacementValuesMap[ICV];
2465   }
2466 
2467   ChangeStatus updateImpl(Attributor &A) override {
2468     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2469     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2470         *this, IRPosition::returned(*getAssociatedFunction()),
2471         DepClassTy::REQUIRED);
2472 
2473     // We don't have any information, so we assume it changes the ICV.
2474     if (!ICVTrackingAA.isAssumedTracked())
2475       return indicatePessimisticFixpoint();
2476 
2477     for (InternalControlVar ICV : TrackableICVs) {
2478       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2479       Optional<Value *> NewReplVal =
2480           ICVTrackingAA.getUniqueReplacementValue(ICV);
2481 
2482       if (ReplVal == NewReplVal)
2483         continue;
2484 
2485       ReplVal = NewReplVal;
2486       Changed = ChangeStatus::CHANGED;
2487     }
2488     return Changed;
2489   }
2490 };
2491 
2492 struct AAExecutionDomainFunction : public AAExecutionDomain {
2493   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2494       : AAExecutionDomain(IRP, A) {}
2495 
2496   const std::string getAsStr() const override {
2497     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2498            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2499   }
2500 
2501   /// See AbstractAttribute::trackStatistics().
2502   void trackStatistics() const override {}
2503 
2504   void initialize(Attributor &A) override {
2505     Function *F = getAnchorScope();
2506     for (const auto &BB : *F)
2507       SingleThreadedBBs.insert(&BB);
2508     NumBBs = SingleThreadedBBs.size();
2509   }
2510 
2511   ChangeStatus manifest(Attributor &A) override {
2512     LLVM_DEBUG({
2513       for (const BasicBlock *BB : SingleThreadedBBs)
2514         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2515                << BB->getName() << " is executed by a single thread.\n";
2516     });
2517     return ChangeStatus::UNCHANGED;
2518   }
2519 
2520   ChangeStatus updateImpl(Attributor &A) override;
2521 
2522   /// Check if an instruction is executed by a single thread.
2523   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2524     return isExecutedByInitialThreadOnly(*I.getParent());
2525   }
2526 
2527   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2528     return isValidState() && SingleThreadedBBs.contains(&BB);
2529   }
2530 
2531   /// Set of basic blocks that are executed by a single thread.
2532   DenseSet<const BasicBlock *> SingleThreadedBBs;
2533 
2534   /// Total number of basic blocks in this function.
2535   long unsigned NumBBs;
2536 };
2537 
2538 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2539   Function *F = getAnchorScope();
2540   ReversePostOrderTraversal<Function *> RPOT(F);
2541   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2542 
2543   bool AllCallSitesKnown;
2544   auto PredForCallSite = [&](AbstractCallSite ACS) {
2545     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2546         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2547         DepClassTy::REQUIRED);
2548     return ACS.isDirectCall() &&
2549            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2550                *ACS.getInstruction());
2551   };
2552 
2553   if (!A.checkForAllCallSites(PredForCallSite, *this,
2554                               /* RequiresAllCallSites */ true,
2555                               AllCallSitesKnown))
2556     SingleThreadedBBs.erase(&F->getEntryBlock());
2557 
2558   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2559   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2560 
2561   // Check if the edge into the successor block contains a condition that only
2562   // lets the main thread execute it.
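  // A matched guard looks roughly like (illustrative):
  //   %tid = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  //   %is.main = icmp eq i32 %tid, 0
  //   br i1 %is.main, label %main.only, label %rest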
2563   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2564     if (!Edge || !Edge->isConditional())
2565       return false;
2566     if (Edge->getSuccessor(0) != SuccessorBB)
2567       return false;
2568 
2569     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2570     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2571       return false;
2572 
2573     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2574     if (!C)
2575       return false;
2576 
2577     // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2578     if (C->isAllOnesValue()) {
2579       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2580       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2581       if (!CB)
2582         return false;
2583       const int InitModeArgNo = 1;
2584       auto *ModeCI = dyn_cast<ConstantInt>(CB->getOperand(InitModeArgNo));
2585       return ModeCI && (ModeCI->getSExtValue() & OMP_TGT_EXEC_MODE_GENERIC);
2586     }
2587 
2588     if (C->isZero()) {
2589       // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2590       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2591         if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2592           return true;
2593 
2594       // Match: 0 == llvm.amdgcn.workitem.id.x()
2595       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2596         if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2597           return true;
2598     }
2599 
2600     return false;
2601   };
2602 
2603   // Merge all the predecessor states into the current basic block. A basic
2604   // block is executed by a single thread if all of its predecessors are.
2605   auto MergePredecessorStates = [&](BasicBlock *BB) {
2606     if (pred_begin(BB) == pred_end(BB))
2607       return SingleThreadedBBs.contains(BB);
2608 
2609     bool IsInitialThread = true;
2610     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2611          PredBB != PredEndBB; ++PredBB) {
2612       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2613                                BB))
2614         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2615     }
2616 
2617     return IsInitialThread;
2618   };
2619 
2620   for (auto *BB : RPOT) {
2621     if (!MergePredecessorStates(BB))
2622       SingleThreadedBBs.erase(BB);
2623   }
2624 
2625   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2626              ? ChangeStatus::UNCHANGED
2627              : ChangeStatus::CHANGED;
2628 }
2629 
2630 /// Try to replace memory allocation calls called by a single thread with a
2631 /// static buffer of shared memory.
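/// For example (illustrative, simplified IR), an allocation like
///   %p = call i8* @__kmpc_alloc_shared(i64 24)
///   ...
///   call void @__kmpc_free_shared(i8* %p)
/// can be replaced by a static buffer in device shared memory:
///   @shared_buf = internal addrspace(3) global [24 x i8] undef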
2632 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2633   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2634   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2635 
2636   /// Create an abstract attribute view for the position \p IRP.
2637   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2638                                            Attributor &A);
2639 
2640   /// Returns true if HeapToShared conversion is assumed to be possible.
2641   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2642 
2643   /// Returns true if HeapToShared conversion is assumed and the CB is a
2644   /// callsite to a free operation to be removed.
2645   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2646 
2647   /// See AbstractAttribute::getName().
2648   const std::string getName() const override { return "AAHeapToShared"; }
2649 
2650   /// See AbstractAttribute::getIdAddr().
2651   const char *getIdAddr() const override { return &ID; }
2652 
2653   /// This function should return true if the type of the \p AA is
2654   /// AAHeapToShared.
2655   static bool classof(const AbstractAttribute *AA) {
2656     return (AA->getIdAddr() == &ID);
2657   }
2658 
2659   /// Unique ID (due to the unique address)
2660   static const char ID;
2661 };
2662 
2663 struct AAHeapToSharedFunction : public AAHeapToShared {
2664   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2665       : AAHeapToShared(IRP, A) {}
2666 
2667   const std::string getAsStr() const override {
2668     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2669            " malloc calls eligible.";
2670   }
2671 
2672   /// See AbstractAttribute::trackStatistics().
2673   void trackStatistics() const override {}
2674 
  /// This function finds free calls that will be removed by the
2676   /// HeapToShared transformation.
2677   void findPotentialRemovedFreeCalls(Attributor &A) {
2678     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2679     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2680 
2681     PotentialRemovedFreeCalls.clear();
2682     // Update free call users of found malloc calls.
2683     for (CallBase *CB : MallocCalls) {
2684       SmallVector<CallBase *, 4> FreeCalls;
2685       for (auto *U : CB->users()) {
2686         CallBase *C = dyn_cast<CallBase>(U);
2687         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2688           FreeCalls.push_back(C);
2689       }
2690 
2691       if (FreeCalls.size() != 1)
2692         continue;
2693 
2694       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2695     }
2696   }
2697 
2698   void initialize(Attributor &A) override {
2699     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2700     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2701 
2702     for (User *U : RFI.Declaration->users())
2703       if (CallBase *CB = dyn_cast<CallBase>(U))
2704         MallocCalls.insert(CB);
2705 
2706     findPotentialRemovedFreeCalls(A);
2707   }
2708 
2709   bool isAssumedHeapToShared(CallBase &CB) const override {
2710     return isValidState() && MallocCalls.count(&CB);
2711   }
2712 
2713   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2714     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2715   }
2716 
2717   ChangeStatus manifest(Attributor &A) override {
2718     if (MallocCalls.empty())
2719       return ChangeStatus::UNCHANGED;
2720 
2721     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2722     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2723 
2724     Function *F = getAnchorScope();
2725     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2726                                             DepClassTy::OPTIONAL);
2727 
2728     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2729     for (CallBase *CB : MallocCalls) {
2730       // Skip replacing this if HeapToStack has already claimed it.
2731       if (HS && HS->isAssumedHeapToStack(*CB))
2732         continue;
2733 
2734       // Find the unique free call to remove it.
2735       SmallVector<CallBase *, 4> FreeCalls;
2736       for (auto *U : CB->users()) {
2737         CallBase *C = dyn_cast<CallBase>(U);
2738         if (C && C->getCalledFunction() == FreeCall.Declaration)
2739           FreeCalls.push_back(C);
2740       }
2741       if (FreeCalls.size() != 1)
2742         continue;
2743 
2744       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2745 
2746       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2747                         << " with " << AllocSize->getZExtValue()
2748                         << " bytes of shared memory\n");
2749 
2750       // Create a new shared memory buffer of the same size as the allocation
2751       // and replace all the uses of the original allocation with it.
2752       Module *M = CB->getModule();
2753       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2754       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2755       auto *SharedMem = new GlobalVariable(
2756           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2757           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2758           GlobalValue::NotThreadLocal,
2759           static_cast<unsigned>(AddressSpace::Shared));
2760       auto *NewBuffer =
2761           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2762 
2763       auto Remark = [&](OptimizationRemark OR) {
2764         return OR << "Replaced globalized variable with "
2765                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2766                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2767                   << "of shared memory.";
2768       };
2769       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2770 
2771       SharedMem->setAlignment(MaybeAlign(32));
2772 
2773       A.changeValueAfterManifest(*CB, *NewBuffer);
2774       A.deleteAfterManifest(*CB);
2775       A.deleteAfterManifest(*FreeCalls.front());
2776 
2777       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2778       Changed = ChangeStatus::CHANGED;
2779     }
2780 
2781     return Changed;
2782   }
2783 
2784   ChangeStatus updateImpl(Attributor &A) override {
2785     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2786     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2787     Function *F = getAnchorScope();
2788 
2789     auto NumMallocCalls = MallocCalls.size();
2790 
2791     // Only consider constant-size malloc calls executed by the initial thread.
2792     for (User *U : RFI.Declaration->users()) {
2793       const auto &ED = A.getAAFor<AAExecutionDomain>(
2794           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2795       if (CallBase *CB = dyn_cast<CallBase>(U))
2796         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2797             !ED.isExecutedByInitialThreadOnly(*CB))
2798           MallocCalls.erase(CB);
2799     }
2800 
2801     findPotentialRemovedFreeCalls(A);
2802 
2803     if (NumMallocCalls != MallocCalls.size())
2804       return ChangeStatus::CHANGED;
2805 
2806     return ChangeStatus::UNCHANGED;
2807   }
2808 
2809   /// Collection of all malloc calls in a function.
2810   SmallPtrSet<CallBase *, 4> MallocCalls;
2811   /// Collection of potentially removed free calls in a function.
2812   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2813 };
2814 
2815 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2816   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2817   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2818 
2819   /// Statistics are tracked as part of manifest for now.
2820   void trackStatistics() const override {}
2821 
2822   /// See AbstractAttribute::getAsStr()
2823   const std::string getAsStr() const override {
2824     if (!isValidState())
2825       return "<invalid>";
2826     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2827                                                             : "generic") +
2828            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2829                                                                : "") +
2830            std::string(" #PRs: ") +
2831            std::to_string(ReachedKnownParallelRegions.size()) +
2832            ", #Unknown PRs: " +
2833            std::to_string(ReachedUnknownParallelRegions.size());
2834   }
2835 
2836   /// Create an abstract attribute view for the position \p IRP.
2837   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2838 
2839   /// See AbstractAttribute::getName()
2840   const std::string getName() const override { return "AAKernelInfo"; }
2841 
2842   /// See AbstractAttribute::getIdAddr()
2843   const char *getIdAddr() const override { return &ID; }
2844 
2845   /// This function should return true if the type of the \p AA is AAKernelInfo
2846   static bool classof(const AbstractAttribute *AA) {
2847     return (AA->getIdAddr() == &ID);
2848   }
2849 
2850   static const char ID;
2851 };
2852 
2853 /// The function kernel info abstract attribute, basically, what can we say
2854 /// about a function with regards to the KernelInfoState.
2855 struct AAKernelInfoFunction : AAKernelInfo {
2856   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2857       : AAKernelInfo(IRP, A) {}
2858 
2859   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2860 
2861   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2862     return GuardedInstructions;
2863   }
2864 
2865   /// See AbstractAttribute::initialize(...).
2866   void initialize(Attributor &A) override {
2867     // This is a high-level transform that might change the constant arguments
2868     // of the init and deinit calls. We need to tell the Attributor about this
2869     // to avoid other parts using the current constant value for simplification.
2870     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2871 
2872     Function *Fn = getAnchorScope();
2873     if (!OMPInfoCache.Kernels.count(Fn))
2874       return;
2875 
2876     // Add itself to the reaching kernel and set IsKernelEntry.
2877     ReachingKernelEntries.insert(Fn);
2878     IsKernelEntry = true;
2879 
2880     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2881         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2882     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2883         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2884 
2885     // For kernels we perform more initialization work: first we find the init
2886     // and deinit calls.
2887     auto StoreCallBase = [](Use &U,
2888                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2889                             CallBase *&Storage) {
2890       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2891       assert(CB &&
2892              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2893       assert(!Storage &&
2894              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2895       Storage = CB;
2896       return false;
2897     };
2898     InitRFI.foreachUse(
2899         [&](Use &U, Function &) {
2900           StoreCallBase(U, InitRFI, KernelInitCB);
2901           return false;
2902         },
2903         Fn);
2904     DeinitRFI.foreachUse(
2905         [&](Use &U, Function &) {
2906           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2907           return false;
2908         },
2909         Fn);
2910 
2911     // Ignore kernels without initializers such as global constructors.
2912     if (!KernelInitCB || !KernelDeinitCB) {
2913       indicateOptimisticFixpoint();
2914       return;
2915     }
2916 
2917     // For kernels we might need to initialize/finalize the IsSPMD state and
2918     // we need to register a simplification callback so that the Attributor
2919     // knows the constant arguments to __kmpc_target_init and
2920     // __kmpc_target_deinit might actually change.
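         // For reference, the calls whose arguments are simplified below look
         // roughly like this (sketch; argument positions match the *ArgNo
         // constants used further down, exact types per the device runtime):
         //   %tid = call i32 @__kmpc_target_init(%struct.ident_t* %loc,
         //              i8 %Mode, i1 %UseGenericStateMachine,
         //              i1 %RequiresFullRuntime)
         //   call void @__kmpc_target_deinit(%struct.ident_t* %loc, i8 %Mode,
         //              i1 %RequiresFullRuntime)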
2921 
2922     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2923         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2924             bool &UsedAssumedInformation) -> Optional<Value *> {
2925       // IRP represents the "use generic state machine" argument of an
2926       // __kmpc_target_init call. We will answer this one with the internal
2927       // state. As long as we are not in an invalid state, we will create a
2928       // custom state machine so the value should be an `i1 false`. If we are
2929       // in an invalid state, we won't change the value that is in the IR.
2930       if (!isValidState())
2931         return nullptr;
2932       // If we have disabled state machine rewrites, don't make a custom one.
2933       if (DisableOpenMPOptStateMachineRewrite)
2934         return nullptr;
2935       if (AA)
2936         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2937       UsedAssumedInformation = !isAtFixpoint();
2938       auto *FalseVal =
2939           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2940       return FalseVal;
2941     };
2942 
2943     Attributor::SimplifictionCallbackTy ModeSimplifyCB =
2944         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2945             bool &UsedAssumedInformation) -> Optional<Value *> {
2946       // IRP represents the execution mode argument of an __kmpc_target_init or
2947       // __kmpc_target_deinit call. We will answer this one with the internal
2948       // state of the SPMDCompatibilityTracker: SPMD if it is assumed, generic
2949       // otherwise.
2950       if (!SPMDCompatibilityTracker.isValidState())
2951         return nullptr;
2952       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2953         if (AA)
2954           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2955         UsedAssumedInformation = true;
2956       } else {
2957         UsedAssumedInformation = false;
2958       }
2959       auto *Val = ConstantInt::getSigned(
2960           IntegerType::getInt8Ty(IRP.getAnchorValue().getContext()),
2961           SPMDCompatibilityTracker.isAssumed() ? OMP_TGT_EXEC_MODE_SPMD
2962                                                : OMP_TGT_EXEC_MODE_GENERIC);
2963       return Val;
2964     };
2965 
2966     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2967         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2968             bool &UsedAssumedInformation) -> Optional<Value *> {
2969       // IRP represents the "RequiresFullRuntime" argument of an
2970       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2971       // one with the internal state of the SPMDCompatibilityTracker, so if
2972       // generic then true, if SPMD then false.
2973       if (!SPMDCompatibilityTracker.isValidState())
2974         return nullptr;
2975       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2976         if (AA)
2977           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2978         UsedAssumedInformation = true;
2979       } else {
2980         UsedAssumedInformation = false;
2981       }
2982       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2983                                        !SPMDCompatibilityTracker.isAssumed());
2984       return Val;
2985     };
2986 
2987     constexpr const int InitModeArgNo = 1;
2988     constexpr const int DeinitModeArgNo = 1;
2989     constexpr const int InitUseStateMachineArgNo = 2;
2990     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2991     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2992     A.registerSimplificationCallback(
2993         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2994         StateMachineSimplifyCB);
2995     A.registerSimplificationCallback(
2996         IRPosition::callsite_argument(*KernelInitCB, InitModeArgNo),
2997         ModeSimplifyCB);
2998     A.registerSimplificationCallback(
2999         IRPosition::callsite_argument(*KernelDeinitCB, DeinitModeArgNo),
3000         ModeSimplifyCB);
3001     A.registerSimplificationCallback(
3002         IRPosition::callsite_argument(*KernelInitCB,
3003                                       InitRequiresFullRuntimeArgNo),
3004         IsGenericModeSimplifyCB);
3005     A.registerSimplificationCallback(
3006         IRPosition::callsite_argument(*KernelDeinitCB,
3007                                       DeinitRequiresFullRuntimeArgNo),
3008         IsGenericModeSimplifyCB);
3009 
3010     // Check if we know we are in SPMD-mode already.
3011     ConstantInt *ModeArg =
3012         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3013     if (ModeArg && (ModeArg->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3014       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3015     // This is a generic region but SPMDization is disabled so stop tracking.
3016     else if (DisableOpenMPOptSPMDization)
3017       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3018   }
3019 
3020   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3021   /// finished now.
3022   ChangeStatus manifest(Attributor &A) override {
3023     // If we are not looking at a kernel with __kmpc_target_init and
3024     // __kmpc_target_deinit call we cannot actually manifest the information.
3025     if (!KernelInitCB || !KernelDeinitCB)
3026       return ChangeStatus::UNCHANGED;
3027 
3028     // Known SPMD-mode kernels need no manifest changes.
3029     if (SPMDCompatibilityTracker.isKnown())
3030       return ChangeStatus::UNCHANGED;
3031 
3032     // If we can, we change the execution mode to SPMD-mode; otherwise we build
3033     // a custom state machine.
3034     if (!mayContainParallelRegion() || !changeToSPMDMode(A))
3035       buildCustomStateMachine(A);
3036 
3037     return ChangeStatus::CHANGED;
3038   }
3039 
3040   bool changeToSPMDMode(Attributor &A) {
3041     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3042 
3043     if (!SPMDCompatibilityTracker.isAssumed()) {
3044       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3045         if (!NonCompatibleI)
3046           continue;
3047 
3048         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3049         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3050           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3051             continue;
3052 
3053         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3054           ORA << "Value has potential side effects preventing SPMD-mode "
3055                  "execution";
3056           if (isa<CallBase>(NonCompatibleI)) {
3057             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3058                    "the called function to override";
3059           }
3060           return ORA << ".";
3061         };
3062         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3063                                                  Remark);
3064 
3065         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3066                           << *NonCompatibleI << "\n");
3067       }
3068 
3069       return false;
3070     }
3071 
3072     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3073                                    Instruction *RegionEndI) {
3074       LoopInfo *LI = nullptr;
3075       DominatorTree *DT = nullptr;
3076       MemorySSAUpdater *MSU = nullptr;
3077       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3078 
3079       BasicBlock *ParentBB = RegionStartI->getParent();
3080       Function *Fn = ParentBB->getParent();
3081       Module &M = *Fn->getParent();
3082 
3083       // Create all the blocks and logic.
3084       // ParentBB:
3085       //    goto RegionCheckTidBB
3086       // RegionCheckTidBB:
3087       //    Tid = __kmpc_hardware_thread_id()
3088       //    if (Tid != 0)
3089       //        goto RegionBarrierBB
3090       // RegionStartBB:
3091       //    <execute instructions guarded>
3092       //    goto RegionEndBB
3093       // RegionEndBB:
3094       //    <store escaping values to shared mem>
3095       //    goto RegionBarrierBB
3096       //  RegionBarrierBB:
3097       //    __kmpc_simple_barrier_spmd()
3098       //    // second barrier is omitted if there are no escaping values.
3099       //    <load escaping values from shared mem>
3100       //    __kmpc_simple_barrier_spmd()
3101       //    goto RegionExitBB
3102       // RegionExitBB:
3103       //    <execute rest of instructions>
3104 
3105       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3106                                            DT, LI, MSU, "region.guarded.end");
3107       BasicBlock *RegionBarrierBB =
3108           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3109                      MSU, "region.barrier");
3110       BasicBlock *RegionExitBB =
3111           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3112                      DT, LI, MSU, "region.exit");
3113       BasicBlock *RegionStartBB =
3114           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3115 
3116       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3117              "Expected a different CFG");
3118 
3119       BasicBlock *RegionCheckTidBB = SplitBlock(
3120           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3121 
3122       // Register basic blocks with the Attributor.
3123       A.registerManifestAddedBasicBlock(*RegionEndBB);
3124       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3125       A.registerManifestAddedBasicBlock(*RegionExitBB);
3126       A.registerManifestAddedBasicBlock(*RegionStartBB);
3127       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3128 
3129       bool HasBroadcastValues = false;
3130       // Find escaping outputs from the guarded region to outside users and
3131       // broadcast their values to them.
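           // As a sketch (names made up): if %v is defined in the guarded
           // region and used in RegionExitBB, we emit
           //   @v.guarded.output.alloc = ... global <ty> undef  ; shared memory
           //   store <ty> %v, ...              ; in RegionEndBB, thread 0 only
           //   %v.guarded.output.load = load <ty>, ...    ; in RegionBarrierBB
           // and rewrite the outside uses to the reloaded value.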
3132       for (Instruction &I : *RegionStartBB) {
3133         SmallPtrSet<Instruction *, 4> OutsideUsers;
3134         for (User *Usr : I.users()) {
3135           Instruction &UsrI = *cast<Instruction>(Usr);
3136           if (UsrI.getParent() != RegionStartBB)
3137             OutsideUsers.insert(&UsrI);
3138         }
3139 
3140         if (OutsideUsers.empty())
3141           continue;
3142 
3143         HasBroadcastValues = true;
3144 
3145         // Emit a global variable in shared memory to store the broadcasted
3146         // value.
3147         auto *SharedMem = new GlobalVariable(
3148             M, I.getType(), /* IsConstant */ false,
3149             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3150             I.getName() + ".guarded.output.alloc", nullptr,
3151             GlobalValue::NotThreadLocal,
3152             static_cast<unsigned>(AddressSpace::Shared));
3153 
3154         // Emit a store instruction to update the value.
3155         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3156 
3157         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3158                                        I.getName() + ".guarded.output.load",
3159                                        RegionBarrierBB->getTerminator());
3160 
3161         // Emit a load instruction and replace uses of the output value.
3162         for (Instruction *UsrI : OutsideUsers) {
3163           assert(UsrI->getParent() == RegionExitBB &&
3164                  "Expected escaping users in exit region");
3165           UsrI->replaceUsesOfWith(&I, LoadI);
3166         }
3167       }
3168 
3169       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3170 
3171       // Go to tid check BB in ParentBB.
3172       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3173       ParentBB->getTerminator()->eraseFromParent();
3174       OpenMPIRBuilder::LocationDescription Loc(
3175           InsertPointTy(ParentBB, ParentBB->end()), DL);
3176       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3177       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3178       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3179       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3180 
3181       // Add check for Tid in RegionCheckTidBB
3182       RegionCheckTidBB->getTerminator()->eraseFromParent();
3183       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3184           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3185       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3186       FunctionCallee HardwareTidFn =
3187           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3188               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3189       Value *Tid =
3190           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3191       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3192       OMPInfoCache.OMPBuilder.Builder
3193           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3194           ->setDebugLoc(DL);
3195 
3196       // First barrier for synchronization, ensures main thread has updated
3197       // values.
3198       FunctionCallee BarrierFn =
3199           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3200               M, OMPRTL___kmpc_barrier_simple_spmd);
3201       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3202           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3203       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3204           ->setDebugLoc(DL);
3205 
3206       // Second barrier ensures workers have read broadcast values.
3207       if (HasBroadcastValues)
3208         CallInst::Create(BarrierFn, {Ident, Tid}, "",
3209                          RegionBarrierBB->getTerminator())
3210             ->setDebugLoc(DL);
3211     };
3212 
3213     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3214     SmallPtrSet<BasicBlock *, 8> Visited;
3215     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3216       BasicBlock *BB = GuardedI->getParent();
3217       if (!Visited.insert(BB).second)
3218         continue;
3219 
3220       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3221       Instruction *LastEffect = nullptr;
3222       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3223       while (++IP != IPEnd) {
3224         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3225           continue;
3226         Instruction *I = &*IP;
3227         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3228           continue;
3229         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3230           LastEffect = nullptr;
3231           continue;
3232         }
3233         if (LastEffect)
3234           Reorders.push_back({I, LastEffect});
3235         LastEffect = &*IP;
3236       }
3237       for (auto &Reorder : Reorders)
3238         Reorder.first->moveBefore(Reorder.second);
3239     }
3240 
3241     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3242 
3243     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3244       BasicBlock *BB = GuardedI->getParent();
3245       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3246           IRPosition::function(*GuardedI->getFunction()), nullptr,
3247           DepClassTy::NONE);
3248       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3249       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3250       // Continue if instruction is already guarded.
3251       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3252         continue;
3253 
3254       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3255       for (Instruction &I : *BB) {
3256         // If instruction I needs to be guarded, update the guarded region
3257         // bounds.
3258         if (SPMDCompatibilityTracker.contains(&I)) {
3259           CalleeAAFunction.getGuardedInstructions().insert(&I);
3260           if (GuardedRegionStart)
3261             GuardedRegionEnd = &I;
3262           else
3263             GuardedRegionStart = GuardedRegionEnd = &I;
3264 
3265           continue;
3266         }
3267 
3268         // Instruction I does not need guarding, store
3269         // any region found and reset bounds.
3270         if (GuardedRegionStart) {
3271           GuardedRegions.push_back(
3272               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3273           GuardedRegionStart = nullptr;
3274           GuardedRegionEnd = nullptr;
3275         }
3276       }
3277     }
3278 
3279     for (auto &GR : GuardedRegions)
3280       CreateGuardedRegion(GR.first, GR.second);
3281 
3282     // Adjust the global exec mode flag that tells the runtime what mode this
3283     // kernel is executed in.
3284     Function *Kernel = getAnchorScope();
3285     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3286         (Kernel->getName() + "_exec_mode").str());
3287     assert(ExecMode && "Kernel without exec mode?");
3288     assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3289 
3290     // Set the global exec mode flag to indicate generic-SPMD mode.
3291     assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3292            "ExecMode is not an integer!");
3293     const int8_t ExecModeVal =
3294         cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3295     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3296            "Initially non-SPMD kernel has SPMD exec mode!");
3297     ExecMode->setInitializer(
3298         ConstantInt::get(ExecMode->getInitializer()->getType(),
3299                          ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
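         // E.g., assuming the usual flag values (GENERIC = 1, SPMD = 2), the
         // initializer of @<kernel>_exec_mode changes from i8 1 (generic) to
         // i8 3 (generic-SPMD).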
3300 
3301     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3302     const int InitModeArgNo = 1;
3303     const int DeinitModeArgNo = 1;
3304     const int InitUseStateMachineArgNo = 2;
3305     const int InitRequiresFullRuntimeArgNo = 3;
3306     const int DeinitRequiresFullRuntimeArgNo = 2;
3307 
3308     auto &Ctx = getAnchorValue().getContext();
3309     A.changeUseAfterManifest(
3310         KernelInitCB->getArgOperandUse(InitModeArgNo),
3311         *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3312                                 OMP_TGT_EXEC_MODE_SPMD));
3313     A.changeUseAfterManifest(
3314         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3315         *ConstantInt::getBool(Ctx, 0));
3316     A.changeUseAfterManifest(
3317         KernelDeinitCB->getArgOperandUse(DeinitModeArgNo),
3318         *ConstantInt::getSigned(IntegerType::getInt8Ty(Ctx),
3319                                 OMP_TGT_EXEC_MODE_SPMD));
3320     A.changeUseAfterManifest(
3321         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3322         *ConstantInt::getBool(Ctx, 0));
3323     A.changeUseAfterManifest(
3324         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3325         *ConstantInt::getBool(Ctx, 0));
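         // After these rewrites the init call reads, roughly (assuming
         // OMP_TGT_EXEC_MODE_SPMD == 2):
         //   call i32 @__kmpc_target_init(%struct.ident_t* %loc, i8 2 /*SPMD*/,
         //       i1 false /*UseGenericStateMachine*/,
         //       i1 false /*RequiresFullRuntime*/)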
3326 
3327     ++NumOpenMPTargetRegionKernelsSPMD;
3328 
3329     auto Remark = [&](OptimizationRemark OR) {
3330       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3331     };
3332     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3333     return true;
3334   }
3335 
3336   ChangeStatus buildCustomStateMachine(Attributor &A) {
3337     // If we have disabled state machine rewrites, don't make a custom one.
3338     if (DisableOpenMPOptStateMachineRewrite)
3339       return indicatePessimisticFixpoint();
3340 
3341     assert(ReachedKnownParallelRegions.isValidState() &&
3342            "Custom state machine with invalid parallel region states?");
3343 
3344     const int InitModeArgNo = 1;
3345     const int InitUseStateMachineArgNo = 2;
3346 
3347     // Check if the current configuration is non-SPMD and uses the generic state
3348     // machine. If we already have SPMD mode or a custom state machine, we do not
3349     // need to go any further. If either argument is anything but a constant,
3350     // something is weird and we give up.
3351     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3352         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3353     ConstantInt *Mode =
3354         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitModeArgNo));
3355 
3356     // If we are stuck with generic mode, try to create a custom device (=GPU)
3357     // state machine which is specialized for the parallel regions that are
3358     // reachable by the kernel.
3359     if (!UseStateMachine || UseStateMachine->isZero() || !Mode ||
3360         (Mode->getSExtValue() & OMP_TGT_EXEC_MODE_SPMD))
3361       return ChangeStatus::UNCHANGED;
3362 
3363     // If not SPMD mode, indicate we use a custom state machine now.
3364     auto &Ctx = getAnchorValue().getContext();
3365     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3366     A.changeUseAfterManifest(
3367         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3368 
3369     // If we don't actually need a state machine we are done here. This can
3370     // happen if there simply are no parallel regions. In the resulting kernel
3371     // all worker threads will simply exit right away, leaving the main thread
3372     // to do the work alone.
3373     if (!mayContainParallelRegion()) {
3374       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3375 
3376       auto Remark = [&](OptimizationRemark OR) {
3377         return OR << "Removing unused state machine from generic-mode kernel.";
3378       };
3379       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3380 
3381       return ChangeStatus::CHANGED;
3382     }
3383 
3384     // Keep track in the statistics of our new shiny custom state machine.
3385     if (ReachedUnknownParallelRegions.empty()) {
3386       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3387 
3388       auto Remark = [&](OptimizationRemark OR) {
3389         return OR << "Rewriting generic-mode kernel with a customized state "
3390                      "machine.";
3391       };
3392       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3393     } else {
3394       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3395 
3396       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3397         return OR << "Generic-mode kernel is executed with a customized state "
3398                      "machine that requires a fallback.";
3399       };
3400       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3401 
3402       // Tell the user why we ended up with a fallback.
3403       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3404         if (!UnknownParallelRegionCB)
3405           continue;
3406         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3407           return ORA << "Call may contain unknown parallel regions. Use "
3408                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3409                         "override.";
3410         };
3411         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3412                                                  "OMP133", Remark);
3413       }
3414     }
3415 
3416     // Create all the blocks:
3417     //
3418     //                       InitCB = __kmpc_target_init(...)
3419     //                       bool IsWorker = InitCB != -1;
3420     //                       if (IsWorker) {
3421     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3422     //                         void *WorkFn;
3423     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3424     //                         if (!WorkFn) return;
3425     // SMIsActiveCheckBB:       if (Active) {
3426     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3427     //                              ParFn0(...);
3428     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3429     //                              ParFn1(...);
3430     //                            ...
3431     // SMIfCascadeCurrentBB:      else
3432     //                              ((WorkFnTy*)WorkFn)(...);
3433     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3434     //                          }
3435     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3436     //                          goto SMBeginBB;
3437     //                       }
3438     // UserCodeEntryBB:      // user code
3439     //                       __kmpc_target_deinit(...)
3440     //
3441     Function *Kernel = getAssociatedFunction();
3442     assert(Kernel && "Expected an associated function!");
3443 
3444     BasicBlock *InitBB = KernelInitCB->getParent();
3445     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3446         KernelInitCB->getNextNode(), "thread.user_code.check");
3447     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3448         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3449     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3450         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3451     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3452         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3453     BasicBlock *StateMachineIfCascadeCurrentBB =
3454         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3455                            Kernel, UserCodeEntryBB);
3456     BasicBlock *StateMachineEndParallelBB =
3457         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3458                            Kernel, UserCodeEntryBB);
3459     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3460         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3461     A.registerManifestAddedBasicBlock(*InitBB);
3462     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3463     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3464     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3465     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3466     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3467     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3468     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3469 
3470     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3471     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3472 
3473     InitBB->getTerminator()->eraseFromParent();
3474     Instruction *IsWorker =
3475         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3476                          ConstantInt::get(KernelInitCB->getType(), -1),
3477                          "thread.is_worker", InitBB);
3478     IsWorker->setDebugLoc(DLoc);
3479     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3480 
3481     Module &M = *Kernel->getParent();
3482 
3483     // Create local storage for the work function pointer.
3484     const DataLayout &DL = M.getDataLayout();
3485     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3486     Instruction *WorkFnAI =
3487         new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3488                        "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3489     WorkFnAI->setDebugLoc(DLoc);
3490 
3491     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3492     OMPInfoCache.OMPBuilder.updateToLocation(
3493         OpenMPIRBuilder::LocationDescription(
3494             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3495                                      StateMachineBeginBB->end()),
3496             DLoc));
3497 
3498     Value *Ident = KernelInitCB->getArgOperand(0);
3499     Value *GTid = KernelInitCB;
3500 
3501     FunctionCallee BarrierFn =
3502         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3503             M, OMPRTL___kmpc_barrier_simple_spmd);
3504     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3505         ->setDebugLoc(DLoc);
3506 
3507     if (WorkFnAI->getType()->getPointerAddressSpace() !=
3508         (unsigned int)AddressSpace::Generic) {
3509       WorkFnAI = new AddrSpaceCastInst(
3510           WorkFnAI,
3511           PointerType::getWithSamePointeeType(
3512               cast<PointerType>(WorkFnAI->getType()),
3513               (unsigned int)AddressSpace::Generic),
3514           WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3515       WorkFnAI->setDebugLoc(DLoc);
3516     }
3517 
3518     FunctionCallee KernelParallelFn =
3519         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3520             M, OMPRTL___kmpc_kernel_parallel);
3521     Instruction *IsActiveWorker = CallInst::Create(
3522         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3523     IsActiveWorker->setDebugLoc(DLoc);
3524     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3525                                        StateMachineBeginBB);
3526     WorkFn->setDebugLoc(DLoc);
3527 
3528     FunctionType *ParallelRegionFnTy = FunctionType::get(
3529         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3530         false);
3531     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3532         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3533         StateMachineBeginBB);
3534 
3535     Instruction *IsDone =
3536         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3537                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3538                          StateMachineBeginBB);
3539     IsDone->setDebugLoc(DLoc);
3540     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3541                        IsDone, StateMachineBeginBB)
3542         ->setDebugLoc(DLoc);
3543 
3544     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3545                        StateMachineDoneBarrierBB, IsActiveWorker,
3546                        StateMachineIsActiveCheckBB)
3547         ->setDebugLoc(DLoc);
3548 
3549     Value *ZeroArg =
3550         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3551 
3552     // Now that we have most of the CFG skeleton it is time for the if-cascade
3553     // that checks the function pointer we got from the runtime against the
3554     // parallel regions we expect, if there are any.
3555     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3556       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3557       BasicBlock *PRExecuteBB = BasicBlock::Create(
3558           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3559           StateMachineEndParallelBB);
3560       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3561           ->setDebugLoc(DLoc);
3562       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3563           ->setDebugLoc(DLoc);
3564 
3565       BasicBlock *PRNextBB =
3566           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3567                              Kernel, StateMachineEndParallelBB);
3568 
3569       // Check if we need to compare the pointer at all or if we can just
3570       // call the parallel region function.
3571       Value *IsPR;
3572       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3573         Instruction *CmpI = ICmpInst::Create(
3574             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3575             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3576         CmpI->setDebugLoc(DLoc);
3577         IsPR = CmpI;
3578       } else {
3579         IsPR = ConstantInt::getTrue(Ctx);
3580       }
3581 
3582       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3583                          StateMachineIfCascadeCurrentBB)
3584           ->setDebugLoc(DLoc);
3585       StateMachineIfCascadeCurrentBB = PRNextBB;
3586     }
3587 
3588     // At the end of the if-cascade we place the indirect function pointer call
3589     // in case we might need it, that is if there can be parallel regions we
3590     // have not handled in the if-cascade above.
3591     if (!ReachedUnknownParallelRegions.empty()) {
3592       StateMachineIfCascadeCurrentBB->setName(
3593           "worker_state_machine.parallel_region.fallback.execute");
3594       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3595                        StateMachineIfCascadeCurrentBB)
3596           ->setDebugLoc(DLoc);
3597     }
3598     BranchInst::Create(StateMachineEndParallelBB,
3599                        StateMachineIfCascadeCurrentBB)
3600         ->setDebugLoc(DLoc);
3601 
3602     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3603                          M, OMPRTL___kmpc_kernel_end_parallel),
3604                      {}, "", StateMachineEndParallelBB)
3605         ->setDebugLoc(DLoc);
3606     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3607         ->setDebugLoc(DLoc);
3608 
3609     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3610         ->setDebugLoc(DLoc);
3611     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3612         ->setDebugLoc(DLoc);
3613 
3614     return ChangeStatus::CHANGED;
3615   }
3616 
3617   /// Fixpoint iteration update function. Will be called every time a
3618   /// dependence changed its state (and in the beginning).
3619   ChangeStatus updateImpl(Attributor &A) override {
3620     KernelInfoState StateBefore = getState();
3621 
3622     // Callback to check a read/write instruction.
3623     auto CheckRWInst = [&](Instruction &I) {
3624       // We handle calls later.
3625       if (isa<CallBase>(I))
3626         return true;
3627       // We only care about write effects.
3628       if (!I.mayWriteToMemory())
3629         return true;
3630       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3631         SmallVector<const Value *> Objects;
3632         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3633         if (llvm::all_of(Objects,
3634                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3635           return true;
3636         // Check for AAHeapToStack moved objects which must not be guarded.
3637         auto &HS = A.getAAFor<AAHeapToStack>(
3638             *this, IRPosition::function(*I.getFunction()),
3639             DepClassTy::REQUIRED);
3640         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3641               auto *CB = dyn_cast<CallBase>(Obj);
3642               if (!CB)
3643                 return false;
3644               return HS.isAssumedHeapToStack(*CB);
3645             })) {
3646           return true;
3647         }
3648       }
3649 
3650       // Insert instruction that needs guarding.
3651       SPMDCompatibilityTracker.insert(&I);
3652       return true;
3653     };
3654 
3655     bool UsedAssumedInformationInCheckRWInst = false;
3656     if (!SPMDCompatibilityTracker.isAtFixpoint())
3657       if (!A.checkForAllReadWriteInstructions(
3658               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3659         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3660 
3661     if (!IsKernelEntry) {
3662       updateReachingKernelEntries(A);
3663       updateParallelLevels(A);
3664 
3665       if (!ParallelLevels.isValidState())
3666         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3667     }
3668 
3669     // Callback to check a call instruction.
3670     bool AllSPMDStatesWereFixed = true;
3671     auto CheckCallInst = [&](Instruction &I) {
3672       auto &CB = cast<CallBase>(I);
3673       auto &CBAA = A.getAAFor<AAKernelInfo>(
3674           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3675       getState() ^= CBAA.getState();
3676       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3677       return true;
3678     };
3679 
3680     bool UsedAssumedInformationInCheckCallInst = false;
3681     if (!A.checkForAllCallLikeInstructions(
3682             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3683       return indicatePessimisticFixpoint();
3684 
3685     // If we haven't used any assumed information for the SPMD state we can fix
3686     // it.
3687     if (!UsedAssumedInformationInCheckRWInst &&
3688         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3689       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3690 
3691     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3692                                      : ChangeStatus::CHANGED;
3693   }
3694 
3695 private:
3696   /// Update info regarding reaching kernels.
3697   void updateReachingKernelEntries(Attributor &A) {
3698     auto PredCallSite = [&](AbstractCallSite ACS) {
3699       Function *Caller = ACS.getInstruction()->getFunction();
3700 
3701       assert(Caller && "Caller is nullptr");
3702 
3703       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3704           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3705       if (CAA.ReachingKernelEntries.isValidState()) {
3706         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3707         return true;
3708       }
3709 
3710       // We lost track of the callers of the associated function; any kernel
3711       // could reach it now.
3712       ReachingKernelEntries.indicatePessimisticFixpoint();
3713 
3714       return true;
3715     };
3716 
3717     bool AllCallSitesKnown;
3718     if (!A.checkForAllCallSites(PredCallSite, *this,
3719                                 true /* RequireAllCallSites */,
3720                                 AllCallSitesKnown))
3721       ReachingKernelEntries.indicatePessimisticFixpoint();
3722   }
3723 
3724   /// Update info regarding parallel levels.
3725   void updateParallelLevels(Attributor &A) {
3726     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3727     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3728         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3729 
3730     auto PredCallSite = [&](AbstractCallSite ACS) {
3731       Function *Caller = ACS.getInstruction()->getFunction();
3732 
3733       assert(Caller && "Caller is nullptr");
3734 
3735       auto &CAA =
3736           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3737       if (CAA.ParallelLevels.isValidState()) {
3738         // Any function that is called by `__kmpc_parallel_51` will not be
3739         // folded as the parallel level in the function is updated. Getting this
3740         // right would make the analysis depend on the runtime implementation,
3741         // and any future change to that implementation could invalidate it. As
3742         // a consequence, we are just conservative here.
3743         if (Caller == Parallel51RFI.Declaration) {
3744           ParallelLevels.indicatePessimisticFixpoint();
3745           return true;
3746         }
3747 
3748         ParallelLevels ^= CAA.ParallelLevels;
3749 
3750         return true;
3751       }
3752 
3753       // We lost track of the callers of the associated function; any kernel
3754       // could reach it now.
3755       ParallelLevels.indicatePessimisticFixpoint();
3756 
3757       return true;
3758     };
3759 
3760     bool AllCallSitesKnown = true;
3761     if (!A.checkForAllCallSites(PredCallSite, *this,
3762                                 true /* RequireAllCallSites */,
3763                                 AllCallSitesKnown))
3764       ParallelLevels.indicatePessimisticFixpoint();
3765   }
3766 };
3767 
3768 /// The call site kernel info abstract attribute, basically, what can we say
3769 /// about a call site with regards to the KernelInfoState. For now this simply
3770 /// forwards the information from the callee.
3771 struct AAKernelInfoCallSite : AAKernelInfo {
3772   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3773       : AAKernelInfo(IRP, A) {}
3774 
3775   /// See AbstractAttribute::initialize(...).
3776   void initialize(Attributor &A) override {
3777     AAKernelInfo::initialize(A);
3778 
3779     CallBase &CB = cast<CallBase>(getAssociatedValue());
3780     Function *Callee = getAssociatedFunction();
3781 
3782     // Helper to lookup an assumption string.
3783     auto HasAssumption = [](CallBase &CB, StringRef AssumptionStr) {
3784       return hasAssumption(CB, AssumptionStr);
3785     };
3786 
3787     // Check for SPMD-mode assumptions.
3788     if (HasAssumption(CB, "ompx_spmd_amenable")) {
3789       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3790       indicateOptimisticFixpoint();
3791     }
3792 
3793     // First weed out calls we do not care about, that is readonly/readnone
3794     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3795     // parallel region or anything else we are looking for.
3796     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3797       indicateOptimisticFixpoint();
3798       return;
3799     }
3800 
3801     // Next we check if we know the callee. If it is a known OpenMP function
3802     // we will handle them explicitly in the switch below. If it is not, we
3803     // will use an AAKernelInfo object on the callee to gather information and
3804     // merge that into the current state. The latter happens in the updateImpl.
3805     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3806     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3807     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3808       // Unknown callees or mere declarations are not analyzable; we give up.
3809       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3810 
3811         // Unknown callees might contain parallel regions, except if they have
3812         // an appropriate assumption attached.
3813         if (!(HasAssumption(CB, "omp_no_openmp") ||
3814               HasAssumption(CB, "omp_no_parallelism")))
3815           ReachedUnknownParallelRegions.insert(&CB);
3816 
3817         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3818         // idea we can run something unknown in SPMD-mode.
3819         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3820           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3821           SPMDCompatibilityTracker.insert(&CB);
3822         }
3823 
3824         // We have updated the state for this unknown call properly; there won't
3825         // be any change, so we indicate a fixpoint.
3826         indicateOptimisticFixpoint();
3827       }
3828       // If the callee is known and can be used in IPO, we will update the state
3829       // based on the callee state in updateImpl.
3830       return;
3831     }
3832 
3833     const unsigned int WrapperFunctionArgNo = 6;
3834     RuntimeFunction RF = It->getSecond();
3835     switch (RF) {
3836     // All the functions we know are compatible with SPMD mode.
3837     case OMPRTL___kmpc_is_spmd_exec_mode:
3838     case OMPRTL___kmpc_for_static_fini:
3839     case OMPRTL___kmpc_global_thread_num:
3840     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3841     case OMPRTL___kmpc_get_hardware_num_blocks:
3842     case OMPRTL___kmpc_single:
3843     case OMPRTL___kmpc_end_single:
3844     case OMPRTL___kmpc_master:
3845     case OMPRTL___kmpc_end_master:
3846     case OMPRTL___kmpc_barrier:
3847       break;
3848     case OMPRTL___kmpc_for_static_init_4:
3849     case OMPRTL___kmpc_for_static_init_4u:
3850     case OMPRTL___kmpc_for_static_init_8:
3851     case OMPRTL___kmpc_for_static_init_8u: {
3852       // Check the schedule and allow static schedule in SPMD mode.
3853       unsigned ScheduleArgOpNo = 2;
3854       auto *ScheduleTypeCI =
3855           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3856       unsigned ScheduleTypeVal =
3857           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3858       switch (OMPScheduleType(ScheduleTypeVal)) {
3859       case OMPScheduleType::Static:
3860       case OMPScheduleType::StaticChunked:
3861       case OMPScheduleType::Distribute:
3862       case OMPScheduleType::DistributeChunked:
3863         break;
3864       default:
3865         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3866         SPMDCompatibilityTracker.insert(&CB);
3867         break;
3868       };
3869     } break;
3870     case OMPRTL___kmpc_target_init:
3871       KernelInitCB = &CB;
3872       break;
3873     case OMPRTL___kmpc_target_deinit:
3874       KernelDeinitCB = &CB;
3875       break;
3876     case OMPRTL___kmpc_parallel_51:
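           // A __kmpc_parallel_51 call looks roughly like this (sketch;
           // argument names made up):
           //   call void @__kmpc_parallel_51(%struct.ident_t* %loc, i32 %gtid,
           //       i32 %if_expr, i32 %num_threads, i32 %proc_bind, i8* %fn,
           //       i8* %wrapper_fn, i8** %args, i64 %nargs)
           // Argument 6 (%wrapper_fn) is the outlined parallel region we try
           // to record here.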
3877       if (auto *ParallelRegion = dyn_cast<Function>(
3878               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3879         ReachedKnownParallelRegions.insert(ParallelRegion);
3880         break;
3881       }
3882       // The condition above should usually get the parallel region function
3883       // pointer and record it. In the off chance it doesn't, we assume the
3884       // worst.
3885       ReachedUnknownParallelRegions.insert(&CB);
3886       break;
3887     case OMPRTL___kmpc_omp_task:
3888       // We do not look into tasks right now, just give up.
3889       SPMDCompatibilityTracker.insert(&CB);
3890       ReachedUnknownParallelRegions.insert(&CB);
3891       indicatePessimisticFixpoint();
3892       return;
3893     case OMPRTL___kmpc_alloc_shared:
3894     case OMPRTL___kmpc_free_shared:
3895       // Return without setting a fixpoint, to be resolved in updateImpl.
3896       return;
3897     default:
3898       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3899       // generally.
3900       SPMDCompatibilityTracker.insert(&CB);
3901       indicatePessimisticFixpoint();
3902       return;
3903     }
3904     // All other OpenMP runtime calls will not reach parallel regions, so they
3905     // can be safely ignored for now. Since this is a known OpenMP runtime call,
3906     // we have modeled all effects and there is no need for any update.
3907     indicateOptimisticFixpoint();
3908   }
3909 
3910   ChangeStatus updateImpl(Attributor &A) override {
3911     // TODO: Once we have call site specific value information we can provide
3912     //       call site specific liveness information and then it makes
3913     //       sense to specialize attributes for call sites arguments instead of
3914     //       redirecting requests to the callee argument.
3915     Function *F = getAssociatedFunction();
3916 
3917     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3918     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3919 
3920     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3921     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3922       const IRPosition &FnPos = IRPosition::function(*F);
3923       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3924       if (getState() == FnAA.getState())
3925         return ChangeStatus::UNCHANGED;
3926       getState() = FnAA.getState();
3927       return ChangeStatus::CHANGED;
3928     }
3929 
3930     // F is a runtime function that allocates or frees memory, check
3931     // AAHeapToStack and AAHeapToShared.
3932     KernelInfoState StateBefore = getState();
3933     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3934             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3935            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3936 
3937     CallBase &CB = cast<CallBase>(getAssociatedValue());
3938 
3939     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3940         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3941     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3942         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3943 
3944     RuntimeFunction RF = It->getSecond();
3945 
3946     switch (RF) {
3947     // If neither HeapToStack nor HeapToShared assume the call is removed,
3948     // assume SPMD incompatibility.
3949     case OMPRTL___kmpc_alloc_shared:
3950       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3951           !HeapToSharedAA.isAssumedHeapToShared(CB))
3952         SPMDCompatibilityTracker.insert(&CB);
3953       break;
3954     case OMPRTL___kmpc_free_shared:
3955       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3956           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3957         SPMDCompatibilityTracker.insert(&CB);
3958       break;
3959     default:
3960       SPMDCompatibilityTracker.insert(&CB);
3961     }
3962 
3963     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3964                                      : ChangeStatus::CHANGED;
3965   }
3966 };
3967 
3968 struct AAFoldRuntimeCall
3969     : public StateWrapper<BooleanState, AbstractAttribute> {
3970   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3971 
3972   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3973 
3974   /// Statistics are tracked as part of manifest for now.
3975   void trackStatistics() const override {}
3976 
3977   /// Create an abstract attribute view for the position \p IRP.
3978   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3979                                               Attributor &A);
3980 
3981   /// See AbstractAttribute::getName()
3982   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3983 
3984   /// See AbstractAttribute::getIdAddr()
3985   const char *getIdAddr() const override { return &ID; }
3986 
3987   /// This function should return true if the type of the \p AA is
3988   /// AAFoldRuntimeCall
3989   static bool classof(const AbstractAttribute *AA) {
3990     return (AA->getIdAddr() == &ID);
3991   }
3992 
3993   static const char ID;
3994 };
3995 
3996 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3997   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3998       : AAFoldRuntimeCall(IRP, A) {}
3999 
4000   /// See AbstractAttribute::getAsStr()
4001   const std::string getAsStr() const override {
4002     if (!isValidState())
4003       return "<invalid>";
4004 
4005     std::string Str("simplified value: ");
4006 
4007     if (!SimplifiedValue.hasValue())
4008       return Str + std::string("none");
4009 
4010     if (!SimplifiedValue.getValue())
4011       return Str + std::string("nullptr");
4012 
4013     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4014       return Str + std::to_string(CI->getSExtValue());
4015 
4016     return Str + std::string("unknown");
4017   }
4018 
4019   void initialize(Attributor &A) override {
4020     if (DisableOpenMPOptFolding)
4021       indicatePessimisticFixpoint();
4022 
4023     Function *Callee = getAssociatedFunction();
4024 
4025     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4026     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4027     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4028            "Expected a known OpenMP runtime function");
4029 
4030     RFKind = It->getSecond();
4031 
4032     CallBase &CB = cast<CallBase>(getAssociatedValue());
4033     A.registerSimplificationCallback(
4034         IRPosition::callsite_returned(CB),
4035         [&](const IRPosition &IRP, const AbstractAttribute *AA,
4036             bool &UsedAssumedInformation) -> Optional<Value *> {
4037           assert((isValidState() || (SimplifiedValue.hasValue() &&
4038                                      SimplifiedValue.getValue() == nullptr)) &&
4039                  "Unexpected invalid state!");
4040 
4041           if (!isAtFixpoint()) {
4042             UsedAssumedInformation = true;
4043             if (AA)
4044               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4045           }
4046           return SimplifiedValue;
4047         });
4048   }
4049 
4050   ChangeStatus updateImpl(Attributor &A) override {
4051     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4052     switch (RFKind) {
4053     case OMPRTL___kmpc_is_spmd_exec_mode:
4054       Changed |= foldIsSPMDExecMode(A);
4055       break;
4056     case OMPRTL___kmpc_is_generic_main_thread_id:
4057       Changed |= foldIsGenericMainThread(A);
4058       break;
4059     case OMPRTL___kmpc_parallel_level:
4060       Changed |= foldParallelLevel(A);
4061       break;
4062     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
4063       Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
4064       break;
4065     case OMPRTL___kmpc_get_hardware_num_blocks:
4066       Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
4067       break;
4068     default:
4069       llvm_unreachable("Unhandled OpenMP runtime function!");
4070     }
4071 
4072     return Changed;
4073   }
4074 
4075   ChangeStatus manifest(Attributor &A) override {
4076     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4077 
4078     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4079       Instruction &I = *getCtxI();
4080       A.changeValueAfterManifest(I, **SimplifiedValue);
4081       A.deleteAfterManifest(I);
4082 
4083       CallBase *CB = dyn_cast<CallBase>(&I);
4084       auto Remark = [&](OptimizationRemark OR) {
4085         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4086           return OR << "Replacing OpenMP runtime call "
4087                     << CB->getCalledFunction()->getName() << " with "
4088                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4089         else
4090           return OR << "Replacing OpenMP runtime call "
4091                     << CB->getCalledFunction()->getName() << ".";
4092       };
4093 
4094       if (CB && EnableVerboseRemarks)
4095         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4096 
4097       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4098                         << **SimplifiedValue << "\n");
4099 
4100       Changed = ChangeStatus::CHANGED;
4101     }
4102 
4103     return Changed;
4104   }
4105 
4106   ChangeStatus indicatePessimisticFixpoint() override {
4107     SimplifiedValue = nullptr;
4108     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4109   }
4110 
4111 private:
4112   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
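  /// For illustration only (sketched IR, not a verbatim test): when every
  /// kernel that can reach the caller is (assumed to be) SPMD, a call such as
  ///   %mode = call i8 @__kmpc_is_spmd_exec_mode()
  /// folds to `i8 1`; when only non-SPMD kernels reach it, it folds to `i8 0`.
  /// Mixed or unknown reachability prevents folding.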
4113   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4114     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4115 
4116     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4117     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4118     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4119         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4120 
4121     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4122       return indicatePessimisticFixpoint();
4123 
4124     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4125       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4126                                           DepClassTy::REQUIRED);
4127 
4128       if (!AA.isValidState()) {
4129         SimplifiedValue = nullptr;
4130         return indicatePessimisticFixpoint();
4131       }
4132 
4133       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4134         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4135           ++KnownSPMDCount;
4136         else
4137           ++AssumedSPMDCount;
4138       } else {
4139         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4140           ++KnownNonSPMDCount;
4141         else
4142           ++AssumedNonSPMDCount;
4143       }
4144     }
4145 
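    // If both SPMD and non-SPMD kernels are assumed to reach the caller, the
    // query result is not uniform and the call cannot be folded.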
4146     if ((AssumedSPMDCount + KnownSPMDCount) &&
4147         (AssumedNonSPMDCount + KnownNonSPMDCount))
4148       return indicatePessimisticFixpoint();
4149 
4150     auto &Ctx = getAnchorValue().getContext();
4151     if (KnownSPMDCount || AssumedSPMDCount) {
4152       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4153              "Expected only SPMD kernels!");
4154       // All reaching kernels are in SPMD mode. Update all function calls to
4155       // __kmpc_is_spmd_exec_mode to 1.
4156       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4157     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4158       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4159              "Expected only non-SPMD kernels!");
4160       // All reaching kernels are in non-SPMD mode. Update all function
4161       // calls to __kmpc_is_spmd_exec_mode to 0.
4162       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4163     } else {
4164       // The set of reaching kernels is empty, so we cannot tell whether the
4165       // associated call site can be folded. At this point, SimplifiedValue
4166       // must be none.
4167       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4168     }
4169 
4170     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4171                                                     : ChangeStatus::CHANGED;
4172   }
4173 
4174   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
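  /// For illustration only (the argument list is sketched and may differ): if
  /// AAExecutionDomain proves the call site is executed solely by the initial
  /// (main) thread, a call such as
  ///   %is_main = call i8 @__kmpc_is_generic_main_thread_id(i32 %tid)
  /// folds to `i8 1`; otherwise we give up and fix pessimistically.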
4175   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4176     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4177 
4178     CallBase &CB = cast<CallBase>(getAssociatedValue());
4179     Function *F = CB.getFunction();
4180     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4181         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4182 
4183     if (!ExecutionDomainAA.isValidState())
4184       return indicatePessimisticFixpoint();
4185 
4186     auto &Ctx = getAnchorValue().getContext();
4187     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4188       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4189     else
4190       return indicatePessimisticFixpoint();
4191 
4192     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4193                                                     : ChangeStatus::CHANGED;
4194   }
4195 
4196   /// Fold __kmpc_parallel_level into a constant if possible.
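  /// For illustration only (call arguments elided in the sketch): if the
  /// caller is reached exclusively from SPMD kernels, the parallel level at
  /// this point is known to be 1, so
  ///   %level = call i8 @__kmpc_parallel_level(...)
  /// folds to `i8 1`; if it is reached only from generic-mode kernels, it
  /// folds to `i8 0`. Mixed reachability prevents folding.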
4197   ChangeStatus foldParallelLevel(Attributor &A) {
4198     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4199 
4200     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4201         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4202 
4203     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4204       return indicatePessimisticFixpoint();
4205 
4206     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4207       return indicatePessimisticFixpoint();
4208 
4209     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
4210       assert(!SimplifiedValue.hasValue() &&
4211              "SimplifiedValue should keep none at this point");
4212       return ChangeStatus::UNCHANGED;
4213     }
4214 
4215     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4216     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4217     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4218       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4219                                           DepClassTy::REQUIRED);
4220       if (!AA.SPMDCompatibilityTracker.isValidState())
4221         return indicatePessimisticFixpoint();
4222 
4223       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4224         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4225           ++KnownSPMDCount;
4226         else
4227           ++AssumedSPMDCount;
4228       } else {
4229         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4230           ++KnownNonSPMDCount;
4231         else
4232           ++AssumedNonSPMDCount;
4233       }
4234     }
4235 
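    // If both SPMD and non-SPMD kernels are assumed to reach the caller, the
    // parallel level is not uniform and the call cannot be folded.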
4236     if ((AssumedSPMDCount + KnownSPMDCount) &&
4237         (AssumedNonSPMDCount + KnownNonSPMDCount))
4238       return indicatePessimisticFixpoint();
4239 
4240     auto &Ctx = getAnchorValue().getContext();
4241     // If the caller can only be reached by SPMD kernel entries, the parallel
4242     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4243     // kernel entries, it is 0.
4244     if (AssumedSPMDCount || KnownSPMDCount) {
4245       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4246              "Expected only SPMD kernels!");
4247       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4248     } else {
4249       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4250              "Expected only non-SPMD kernels!");
4251       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4252     }
4253     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4254                                                     : ChangeStatus::CHANGED;
4255   }
4256 
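  /// Fold a hardware-query runtime call into the constant value of the kernel
  /// attribute \p Attr, provided all reaching kernels agree on that value.
  /// For illustration only (the value is made up): if every reaching kernel
  /// carries "omp_target_thread_limit"="128", a call to
  /// __kmpc_get_hardware_num_threads_in_block() folds to `i32 128`.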
4257   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
4258     // Specialize only if all reaching kernels agree on the attribute value.
4259     int32_t CurrentAttrValue = -1;
4260     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4261 
4262     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4263         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4264 
4265     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4266       return indicatePessimisticFixpoint();
4267 
4268     // Iterate over the kernels that reach this function
4269     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4270       int32_t NextAttrVal = -1;
4271       if (K->hasFnAttribute(Attr))
4272         NextAttrVal =
4273             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4274 
4275       if (NextAttrVal == -1 ||
4276           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4277         return indicatePessimisticFixpoint();
4278       CurrentAttrValue = NextAttrVal;
4279     }
4280 
4281     if (CurrentAttrValue != -1) {
4282       auto &Ctx = getAnchorValue().getContext();
4283       SimplifiedValue =
4284           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4285     }
4286     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4287                                                     : ChangeStatus::CHANGED;
4288   }
4289 
4290   /// An optional value the associated value is assumed to fold to. That is, we
4291   /// assume the associated value (which is a call) can be replaced by this
4292   /// simplified value.
4293   Optional<Value *> SimplifiedValue;
4294 
4295   /// The runtime function kind of the callee of the associated call site.
4296   RuntimeFunction RFKind;
4297 };
4298 
4299 } // namespace
4300 
4301 /// Register AAFoldRuntimeCall AAs for all call sites of \p RF.
4302 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4303   auto &RFI = OMPInfoCache.RFIs[RF];
4304   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4305     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4306     if (!CI)
4307       return false;
4308     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4309         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4310         DepClassTy::NONE, /* ForceUpdate */ false,
4311         /* UpdateAfterInit */ false);
4312     return false;
4313   });
4314 }
4315 
4316 void OpenMPOpt::registerAAs(bool IsModulePass) {
4317   if (SCC.empty())
4319     return;
4320   if (IsModulePass) {
4321     // Ensure we create the AAKernelInfo AAs first and without triggering an
4322     // update. This will make sure we register all value simplification
4323     // callbacks before any other AA has the chance to create an AAValueSimplify
4324     // or similar.
4325     for (Function *Kernel : OMPInfoCache.Kernels)
4326       A.getOrCreateAAFor<AAKernelInfo>(
4327           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4328           DepClassTy::NONE, /* ForceUpdate */ false,
4329           /* UpdateAfterInit */ false);
4330 
4331     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4332     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4333     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4334     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4335     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4336   }
4337 
4338   // Create CallSite AA for all Getters.
4339   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4340     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4341 
4342     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4343 
4344     auto CreateAA = [&](Use &U, Function &Caller) {
4345       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4346       if (!CI)
4347         return false;
4348 
4349       auto &CB = cast<CallBase>(*CI);
4350 
4351       IRPosition CBPos = IRPosition::callsite_function(CB);
4352       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4353       return false;
4354     };
4355 
4356     GetterRFI.foreachUse(SCC, CreateAA);
4357   }
4358   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4359   auto CreateAA = [&](Use &U, Function &F) {
4360     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4361     return false;
4362   };
4363   if (!DisableOpenMPOptDeglobalization)
4364     GlobalizationRFI.foreachUse(SCC, CreateAA);
4365 
4366   // Create an ExecutionDomain AA for every function and a HeapToStack AA for
4367   // every function if there is a device kernel.
4368   if (!isOpenMPDevice(M))
4369     return;
4370 
4371   for (auto *F : SCC) {
4372     if (F->isDeclaration())
4373       continue;
4374 
4375     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4376     if (!DisableOpenMPOptDeglobalization)
4377       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4378 
4379     for (auto &I : instructions(*F)) {
4380       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4381         bool UsedAssumedInformation = false;
4382         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4383                                UsedAssumedInformation);
4384       }
4385     }
4386   }
4387 }
4388 
4389 const char AAICVTracker::ID = 0;
4390 const char AAKernelInfo::ID = 0;
4391 const char AAExecutionDomain::ID = 0;
4392 const char AAHeapToShared::ID = 0;
4393 const char AAFoldRuntimeCall::ID = 0;
4394 
4395 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4396                                               Attributor &A) {
4397   AAICVTracker *AA = nullptr;
4398   switch (IRP.getPositionKind()) {
4399   case IRPosition::IRP_INVALID:
4400   case IRPosition::IRP_FLOAT:
4401   case IRPosition::IRP_ARGUMENT:
4402   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4403     llvm_unreachable("ICVTracker can only be created for function position!");
4404   case IRPosition::IRP_RETURNED:
4405     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4406     break;
4407   case IRPosition::IRP_CALL_SITE_RETURNED:
4408     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4409     break;
4410   case IRPosition::IRP_CALL_SITE:
4411     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4412     break;
4413   case IRPosition::IRP_FUNCTION:
4414     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4415     break;
4416   }
4417 
4418   return *AA;
4419 }
4420 
4421 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4422                                                         Attributor &A) {
4423   AAExecutionDomainFunction *AA = nullptr;
4424   switch (IRP.getPositionKind()) {
4425   case IRPosition::IRP_INVALID:
4426   case IRPosition::IRP_FLOAT:
4427   case IRPosition::IRP_ARGUMENT:
4428   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4429   case IRPosition::IRP_RETURNED:
4430   case IRPosition::IRP_CALL_SITE_RETURNED:
4431   case IRPosition::IRP_CALL_SITE:
4432     llvm_unreachable(
4433         "AAExecutionDomain can only be created for function position!");
4434   case IRPosition::IRP_FUNCTION:
4435     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4436     break;
4437   }
4438 
4439   return *AA;
4440 }
4441 
4442 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4443                                                   Attributor &A) {
4444   AAHeapToSharedFunction *AA = nullptr;
4445   switch (IRP.getPositionKind()) {
4446   case IRPosition::IRP_INVALID:
4447   case IRPosition::IRP_FLOAT:
4448   case IRPosition::IRP_ARGUMENT:
4449   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4450   case IRPosition::IRP_RETURNED:
4451   case IRPosition::IRP_CALL_SITE_RETURNED:
4452   case IRPosition::IRP_CALL_SITE:
4453     llvm_unreachable(
4454         "AAHeapToShared can only be created for function position!");
4455   case IRPosition::IRP_FUNCTION:
4456     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4457     break;
4458   }
4459 
4460   return *AA;
4461 }
4462 
4463 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4464                                               Attributor &A) {
4465   AAKernelInfo *AA = nullptr;
4466   switch (IRP.getPositionKind()) {
4467   case IRPosition::IRP_INVALID:
4468   case IRPosition::IRP_FLOAT:
4469   case IRPosition::IRP_ARGUMENT:
4470   case IRPosition::IRP_RETURNED:
4471   case IRPosition::IRP_CALL_SITE_RETURNED:
4472   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4473     llvm_unreachable("KernelInfo can only be created for function position!");
4474   case IRPosition::IRP_CALL_SITE:
4475     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4476     break;
4477   case IRPosition::IRP_FUNCTION:
4478     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4479     break;
4480   }
4481 
4482   return *AA;
4483 }
4484 
4485 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4486                                                         Attributor &A) {
4487   AAFoldRuntimeCall *AA = nullptr;
4488   switch (IRP.getPositionKind()) {
4489   case IRPosition::IRP_INVALID:
4490   case IRPosition::IRP_FLOAT:
4491   case IRPosition::IRP_ARGUMENT:
4492   case IRPosition::IRP_RETURNED:
4493   case IRPosition::IRP_FUNCTION:
4494   case IRPosition::IRP_CALL_SITE:
4495   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4496     llvm_unreachable(
        "AAFoldRuntimeCall can only be created for call site position!");
4497   case IRPosition::IRP_CALL_SITE_RETURNED:
4498     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4499     break;
4500   }
4501 
4502   return *AA;
4503 }
4504 
4505 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4506   if (!containsOpenMP(M))
4507     return PreservedAnalyses::all();
4508   if (DisableOpenMPOptimizations)
4509     return PreservedAnalyses::all();
4510 
4511   FunctionAnalysisManager &FAM =
4512       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4513   KernelSet Kernels = getDeviceKernels(M);
4514 
4515   auto IsCalled = [&](Function &F) {
4516     if (Kernels.contains(&F))
4517       return true;
4518     for (const User *U : F.users())
4519       if (!isa<BlockAddress>(U))
4520         return true;
4521     return false;
4522   };
4523 
4524   auto EmitRemark = [&](Function &F) {
4525     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4526     ORE.emit([&]() {
4527       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4528       return ORA << "Could not internalize function. "
4529                  << "Some optimizations may not be possible. [OMP140]";
4530     });
4531   };
4532 
4533   // Create internal copies of each function if this is a kernel Module. This
4534   // allows interprocedural passes to see every call edge.
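  // For illustration only (the exact renaming is an Attributor implementation
  // detail): a non-kernel definition such as @helper is cloned into an
  // internal-linkage copy (e.g. @helper.internalized), and calls inside the
  // module are redirected to that copy, so it can be specialized freely while
  // the original definition remains available to external callers.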
4535   DenseMap<Function *, Function *> InternalizedMap;
4536   if (isOpenMPDevice(M)) {
4537     SmallPtrSet<Function *, 16> InternalizeFns;
4538     for (Function &F : M)
4539       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4540           !DisableInternalization) {
4541         if (Attributor::isInternalizable(F)) {
4542           InternalizeFns.insert(&F);
4543         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4544           EmitRemark(F);
4545         }
4546       }
4547 
4548     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4549   }
4550 
4551   // Look at every function in the Module unless it was internalized.
4552   SmallVector<Function *, 16> SCC;
4553   for (Function &F : M)
4554     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4555       SCC.push_back(&F);
4556 
4557   if (SCC.empty())
4558     return PreservedAnalyses::all();
4559 
4560   AnalysisGetter AG(FAM);
4561 
4562   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4563     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4564   };
4565 
4566   BumpPtrAllocator Allocator;
4567   CallGraphUpdater CGUpdater;
4568 
4569   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4570   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4571 
4572   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4573   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4574                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4575 
4576   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4577   bool Changed = OMPOpt.run(true);
4578 
4579   // Optionally inline device functions for potentially better performance.
4580   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4581     for (Function &F : M)
4582       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4583           !F.hasFnAttribute(Attribute::NoInline))
4584         F.addFnAttr(Attribute::AlwaysInline);
4585 
4586   if (PrintModuleAfterOptimizations)
4587     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4588 
4589   if (Changed)
4590     return PreservedAnalyses::none();
4591 
4592   return PreservedAnalyses::all();
4593 }
4594 
4595 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4596                                           CGSCCAnalysisManager &AM,
4597                                           LazyCallGraph &CG,
4598                                           CGSCCUpdateResult &UR) {
4599   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4600     return PreservedAnalyses::all();
4601   if (DisableOpenMPOptimizations)
4602     return PreservedAnalyses::all();
4603 
4604   SmallVector<Function *, 16> SCC;
4605   // If there are kernels in the module, we have to run on all SCCs.
4606   for (LazyCallGraph::Node &N : C) {
4607     Function *Fn = &N.getFunction();
4608     SCC.push_back(Fn);
4609   }
4610 
4611   if (SCC.empty())
4612     return PreservedAnalyses::all();
4613 
4614   Module &M = *C.begin()->getFunction().getParent();
4615 
4616   KernelSet Kernels = getDeviceKernels(M);
4617 
4618   FunctionAnalysisManager &FAM =
4619       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4620 
4621   AnalysisGetter AG(FAM);
4622 
4623   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4624     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4625   };
4626 
4627   BumpPtrAllocator Allocator;
4628   CallGraphUpdater CGUpdater;
4629   CGUpdater.initialize(CG, C, AM, UR);
4630 
4631   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4632   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4633                                 /*CGSCC*/ Functions, Kernels);
4634 
4635   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4636   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4637                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4638 
4639   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4640   bool Changed = OMPOpt.run(false);
4641 
4642   if (PrintModuleAfterOptimizations)
4643     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4644 
4645   if (Changed)
4646     return PreservedAnalyses::none();
4647 
4648   return PreservedAnalyses::all();
4649 }
4650 
4651 namespace {
4652 
4653 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4654   CallGraphUpdater CGUpdater;
4655   static char ID;
4656 
4657   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4658     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4659   }
4660 
4661   void getAnalysisUsage(AnalysisUsage &AU) const override {
4662     CallGraphSCCPass::getAnalysisUsage(AU);
4663   }
4664 
4665   bool runOnSCC(CallGraphSCC &CGSCC) override {
4666     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4667       return false;
4668     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4669       return false;
4670 
4671     SmallVector<Function *, 16> SCC;
4672     // If there are kernels in the module, we have to run on all SCCs.
4673     for (CallGraphNode *CGN : CGSCC) {
4674       Function *Fn = CGN->getFunction();
4675       if (!Fn || Fn->isDeclaration())
4676         continue;
4677       SCC.push_back(Fn);
4678     }
4679 
4680     if (SCC.empty())
4681       return false;
4682 
4683     Module &M = CGSCC.getCallGraph().getModule();
4684     KernelSet Kernels = getDeviceKernels(M);
4685 
4686     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4687     CGUpdater.initialize(CG, CGSCC);
4688 
4689     // Maintain a map of functions to avoid rebuilding the ORE
4690     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4691     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4692       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4693       if (!ORE)
4694         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4695       return *ORE;
4696     };
4697 
4698     AnalysisGetter AG;
4699     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4700     BumpPtrAllocator Allocator;
4701     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4702                                   Allocator,
4703                                   /*CGSCC*/ Functions, Kernels);
4704 
4705     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4706     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4707                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4708 
4709     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4710     bool Result = OMPOpt.run(false);
4711 
4712     if (PrintModuleAfterOptimizations)
4713       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4714 
4715     return Result;
4716   }
4717 
4718   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4719 };
4720 
4721 } // end anonymous namespace
4722 
4723 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4724   // TODO: Create a more cross-platform way of determining device kernels.
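  // For illustration only (the kernel name is made up): the entries consumed
  // below look like
  //   !nvvm.annotations = !{!0}
  //   !0 = !{void ()* @__omp_offloading_xy_main_l1, !"kernel", i32 1}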
4725   NamedMDNode *MD = M.getOrInsertNamedMetadata("nvvm.annotations");
4726   KernelSet Kernels;
4727 
4728   if (!MD)
4729     return Kernels;
4730 
4731   for (auto *Op : MD->operands()) {
4732     if (Op->getNumOperands() < 2)
4733       continue;
4734     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4735     if (!KindID || KindID->getString() != "kernel")
4736       continue;
4737 
4738     Function *KernelFn =
4739         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4740     if (!KernelFn)
4741       continue;
4742 
4743     ++NumOpenMPTargetRegionKernels;
4744 
4745     Kernels.insert(KernelFn);
4746   }
4747 
4748   return Kernels;
4749 }
4750 
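// Both queries below inspect module flags emitted by the frontend. For
// illustration only (the OpenMP version value is just an example):
//   !llvm.module.flags = !{!0, !1}
//   !0 = !{i32 7, !"openmp", i32 50}
//   !1 = !{i32 7, !"openmp-device", i32 50}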
4751 bool llvm::omp::containsOpenMP(Module &M) {
4752   Metadata *MD = M.getModuleFlag("openmp");
4753   if (!MD)
4754     return false;
4755 
4756   return true;
4757 }
4758 
4759 bool llvm::omp::isOpenMPDevice(Module &M) {
4760   Metadata *MD = M.getModuleFlag("openmp-device");
4761   if (!MD)
4762     return false;
4763 
4764   return true;
4765 }
4766 
4767 char OpenMPOptCGSCCLegacyPass::ID = 0;
4768 
4769 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4770                       "OpenMP specific optimizations", false, false)
4771 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4772 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4773                     "OpenMP specific optimizations", false, false)
4774 
4775 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4776   return new OpenMPOptCGSCCLegacyPass();
4777 }
4778