1 //===-- IPO/OpenMPOpt.cpp - Collection of OpenMP specific optimizations ---===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // OpenMP specific optimizations:
10 //
11 // - Deduplication of runtime calls, e.g., omp_get_thread_num.
12 // - Replacing globalized device memory with stack memory.
13 // - Replacing globalized device memory with shared memory.
14 // - Parallel region merging.
15 // - Transforming generic-mode device kernels to SPMD mode.
16 // - Specializing the state machine for generic-mode device kernels.
17 //
18 //===----------------------------------------------------------------------===//
19 
20 #include "llvm/Transforms/IPO/OpenMPOpt.h"
21 
22 #include "llvm/ADT/EnumeratedArray.h"
23 #include "llvm/ADT/PostOrderIterator.h"
24 #include "llvm/ADT/Statistic.h"
25 #include "llvm/Analysis/CallGraph.h"
26 #include "llvm/Analysis/CallGraphSCCPass.h"
27 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/Frontend/OpenMP/OMPConstants.h"
30 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
31 #include "llvm/IR/Assumptions.h"
32 #include "llvm/IR/DiagnosticInfo.h"
33 #include "llvm/IR/GlobalValue.h"
34 #include "llvm/IR/Instruction.h"
35 #include "llvm/IR/IntrinsicInst.h"
36 #include "llvm/IR/IntrinsicsAMDGPU.h"
37 #include "llvm/IR/IntrinsicsNVPTX.h"
38 #include "llvm/InitializePasses.h"
39 #include "llvm/Support/CommandLine.h"
40 #include "llvm/Transforms/IPO.h"
41 #include "llvm/Transforms/IPO/Attributor.h"
42 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
43 #include "llvm/Transforms/Utils/CallGraphUpdater.h"
44 #include "llvm/Transforms/Utils/CodeExtractor.h"
45 
46 using namespace llvm;
47 using namespace omp;
48 
49 #define DEBUG_TYPE "openmp-opt"
50 
51 static cl::opt<bool> DisableOpenMPOptimizations(
52     "openmp-opt-disable", cl::ZeroOrMore,
53     cl::desc("Disable OpenMP specific optimizations."), cl::Hidden,
54     cl::init(false));
55 
56 static cl::opt<bool> EnableParallelRegionMerging(
57     "openmp-opt-enable-merging", cl::ZeroOrMore,
58     cl::desc("Enable the OpenMP region merging optimization."), cl::Hidden,
59     cl::init(false));
60 
61 static cl::opt<bool>
62     DisableInternalization("openmp-opt-disable-internalization", cl::ZeroOrMore,
63                            cl::desc("Disable function internalization."),
64                            cl::Hidden, cl::init(false));
65 
66 static cl::opt<bool> PrintICVValues("openmp-print-icv-values", cl::init(false),
67                                     cl::Hidden);
68 static cl::opt<bool> PrintOpenMPKernels("openmp-print-gpu-kernels",
69                                         cl::init(false), cl::Hidden);
70 
71 static cl::opt<bool> HideMemoryTransferLatency(
72     "openmp-hide-memory-transfer-latency",
73     cl::desc("[WIP] Tries to hide the latency of host to device memory"
74              " transfers"),
75     cl::Hidden, cl::init(false));
76 
77 static cl::opt<bool> DisableOpenMPOptDeglobalization(
78     "openmp-opt-disable-deglobalization", cl::ZeroOrMore,
79     cl::desc("Disable OpenMP optimizations involving deglobalization."),
80     cl::Hidden, cl::init(false));
81 
82 static cl::opt<bool> DisableOpenMPOptSPMDization(
83     "openmp-opt-disable-spmdization", cl::ZeroOrMore,
84     cl::desc("Disable OpenMP optimizations involving SPMD-ization."),
85     cl::Hidden, cl::init(false));
86 
87 static cl::opt<bool> DisableOpenMPOptFolding(
88     "openmp-opt-disable-folding", cl::ZeroOrMore,
89     cl::desc("Disable OpenMP optimizations involving folding."), cl::Hidden,
90     cl::init(false));
91 
92 static cl::opt<bool> DisableOpenMPOptStateMachineRewrite(
93     "openmp-opt-disable-state-machine-rewrite", cl::ZeroOrMore,
94     cl::desc("Disable OpenMP optimizations that replace the state machine."),
95     cl::Hidden, cl::init(false));
96 
97 static cl::opt<bool> PrintModuleAfterOptimizations(
98     "openmp-opt-print-module", cl::ZeroOrMore,
99     cl::desc("Print the current module after OpenMP optimizations."),
100     cl::Hidden, cl::init(false));
101 
102 static cl::opt<bool> AlwaysInlineDeviceFunctions(
103     "openmp-opt-inline-device", cl::ZeroOrMore,
104     cl::desc("Inline all applicable functions on the device."), cl::Hidden,
105     cl::init(false));
106 
107 static cl::opt<bool>
108     EnableVerboseRemarks("openmp-opt-verbose-remarks", cl::ZeroOrMore,
109                          cl::desc("Enables more verbose remarks."), cl::Hidden,
110                          cl::init(false));
111 
112 STATISTIC(NumOpenMPRuntimeCallsDeduplicated,
113           "Number of OpenMP runtime calls deduplicated");
114 STATISTIC(NumOpenMPParallelRegionsDeleted,
115           "Number of OpenMP parallel regions deleted");
116 STATISTIC(NumOpenMPRuntimeFunctionsIdentified,
117           "Number of OpenMP runtime functions identified");
118 STATISTIC(NumOpenMPRuntimeFunctionUsesIdentified,
119           "Number of OpenMP runtime function uses identified");
120 STATISTIC(NumOpenMPTargetRegionKernels,
121           "Number of OpenMP target region entry points (=kernels) identified");
122 STATISTIC(NumOpenMPTargetRegionKernelsSPMD,
123           "Number of OpenMP target region entry points (=kernels) executed in "
124           "SPMD-mode instead of generic-mode");
125 STATISTIC(NumOpenMPTargetRegionKernelsWithoutStateMachine,
126           "Number of OpenMP target region entry points (=kernels) executed in "
127           "generic-mode without a state machine");
128 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback,
129           "Number of OpenMP target region entry points (=kernels) executed in "
130           "generic-mode with customized state machines with fallback");
131 STATISTIC(NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback,
132           "Number of OpenMP target region entry points (=kernels) executed in "
133           "generic-mode with customized state machines without fallback");
134 STATISTIC(
135     NumOpenMPParallelRegionsReplacedInGPUStateMachine,
136     "Number of OpenMP parallel regions replaced with ID in GPU state machines");
137 STATISTIC(NumOpenMPParallelRegionsMerged,
138           "Number of OpenMP parallel regions merged");
139 STATISTIC(NumBytesMovedToSharedMemory,
140           "Number of bytes moved to shared memory");
141 
142 #if !defined(NDEBUG)
143 static constexpr auto TAG = "[" DEBUG_TYPE "]";
144 #endif
145 
146 namespace {
147 
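/// Device address space numbers used below. These follow the numbering
/// commonly used by the NVPTX and AMDGPU backends (0 = generic/flat,
/// 1 = global, 3 = shared/LDS, 4 = constant, 5 = local/private).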
148 enum class AddressSpace : unsigned {
149   Generic = 0,
150   Global = 1,
151   Shared = 3,
152   Constant = 4,
153   Local = 5,
154 };
155 
156 struct AAHeapToShared;
157 
158 struct AAICVTracker;
159 
160 /// OpenMP specific information. For now, stores RFIs and ICVs also needed for
161 /// Attributor runs.
162 struct OMPInformationCache : public InformationCache {
163   OMPInformationCache(Module &M, AnalysisGetter &AG,
164                       BumpPtrAllocator &Allocator, SetVector<Function *> &CGSCC,
165                       SmallPtrSetImpl<Kernel> &Kernels)
166       : InformationCache(M, AG, Allocator, &CGSCC), OMPBuilder(M),
167         Kernels(Kernels) {
168 
169     OMPBuilder.initialize();
170     initializeRuntimeFunctions();
171     initializeInternalControlVars();
172   }
173 
174   /// Generic information that describes an internal control variable.
175   struct InternalControlVarInfo {
176     /// The kind, as described by InternalControlVar enum.
177     InternalControlVar Kind;
178 
179     /// The name of the ICV.
180     StringRef Name;
181 
182     /// Environment variable associated with this ICV.
183     StringRef EnvVarName;
184 
185     /// Initial value kind.
186     ICVInitValue InitKind;
187 
188     /// Initial value.
189     ConstantInt *InitValue;
190 
191     /// Setter RTL function associated with this ICV.
192     RuntimeFunction Setter;
193 
194     /// Getter RTL function associated with this ICV.
195     RuntimeFunction Getter;
196 
197     /// RTL Function corresponding to the override clause of this ICV
198     RuntimeFunction Clause;
199   };
200 
201   /// Generic information that describes a runtime function.
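  /// A typical use, as seen later in this file, looks roughly like:
  ///   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
  ///   if (RFI.Declaration)
  ///     RFI.foreachUse(SCC, [&](Use &U, Function &F) { ...; return false; });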
202   struct RuntimeFunctionInfo {
203 
204     /// The kind, as described by the RuntimeFunction enum.
205     RuntimeFunction Kind;
206 
207     /// The name of the function.
208     StringRef Name;
209 
210     /// Flag to indicate a variadic function.
211     bool IsVarArg;
212 
213     /// The return type of the function.
214     Type *ReturnType;
215 
216     /// The argument types of the function.
217     SmallVector<Type *, 8> ArgumentTypes;
218 
219     /// The declaration if available.
220     Function *Declaration = nullptr;
221 
222     /// Uses of this runtime function per function containing the use.
223     using UseVector = SmallVector<Use *, 16>;
224 
225     /// Clear UsesMap for runtime function.
226     void clearUsesMap() { UsesMap.clear(); }
227 
228     /// Boolean conversion that is true if the runtime function was found.
229     operator bool() const { return Declaration; }
230 
231     /// Return the vector of uses in function \p F.
232     UseVector &getOrCreateUseVector(Function *F) {
233       std::shared_ptr<UseVector> &UV = UsesMap[F];
234       if (!UV)
235         UV = std::make_shared<UseVector>();
236       return *UV;
237     }
238 
239     /// Return the vector of uses in function \p F or `nullptr` if there are
240     /// none.
241     const UseVector *getUseVector(Function &F) const {
242       auto I = UsesMap.find(&F);
243       if (I != UsesMap.end())
244         return I->second.get();
245       return nullptr;
246     }
247 
248     /// Return how many functions contain uses of this runtime function.
249     size_t getNumFunctionsWithUses() const { return UsesMap.size(); }
250 
251     /// Return the number of arguments (or the minimal number for variadic
252     /// functions).
253     size_t getNumArgs() const { return ArgumentTypes.size(); }
254 
255     /// Run the callback \p CB on each use and forget the use if the result is
256     /// true. The callback will be fed the function in which the use was
257     /// encountered as the second argument.
258     void foreachUse(SmallVectorImpl<Function *> &SCC,
259                     function_ref<bool(Use &, Function &)> CB) {
260       for (Function *F : SCC)
261         foreachUse(CB, F);
262     }
263 
264     /// Run the callback \p CB on each use within the function \p F and forget
265     /// the use if the result is true.
266     void foreachUse(function_ref<bool(Use &, Function &)> CB, Function *F) {
267       SmallVector<unsigned, 8> ToBeDeleted;
269 
270       unsigned Idx = 0;
271       UseVector &UV = getOrCreateUseVector(F);
272 
273       for (Use *U : UV) {
274         if (CB(*U, *F))
275           ToBeDeleted.push_back(Idx);
276         ++Idx;
277       }
278 
279       // Remove the to-be-deleted indices in reverse order as prior
280       // modifications will not modify the smaller indices.
281       while (!ToBeDeleted.empty()) {
282         unsigned Idx = ToBeDeleted.pop_back_val();
283         UV[Idx] = UV.back();
284         UV.pop_back();
285       }
286     }
287 
288   private:
289     /// Map from functions to all uses of this runtime function contained in
290     /// them.
291     DenseMap<Function *, std::shared_ptr<UseVector>> UsesMap;
292 
293   public:
294     /// Iterators for the uses of this runtime function.
295     decltype(UsesMap)::iterator begin() { return UsesMap.begin(); }
296     decltype(UsesMap)::iterator end() { return UsesMap.end(); }
297   };
298 
299   /// An OpenMP-IR-Builder instance
300   OpenMPIRBuilder OMPBuilder;
301 
302   /// Map from runtime function kind to the runtime function description.
303   EnumeratedArray<RuntimeFunctionInfo, RuntimeFunction,
304                   RuntimeFunction::OMPRTL___last>
305       RFIs;
306 
307   /// Map from function declarations/definitions to their runtime enum type.
308   DenseMap<Function *, RuntimeFunction> RuntimeFunctionIDMap;
309 
310   /// Map from ICV kind to the ICV description.
311   EnumeratedArray<InternalControlVarInfo, InternalControlVar,
312                   InternalControlVar::ICV___last>
313       ICVs;
314 
315   /// Helper to initialize all internal control variable information for those
316   /// defined in OMPKinds.def.
317   void initializeInternalControlVars() {
318 #define ICV_RT_SET(_Name, RTL)                                                 \
319   {                                                                            \
320     auto &ICV = ICVs[_Name];                                                   \
321     ICV.Setter = RTL;                                                          \
322   }
323 #define ICV_RT_GET(Name, RTL)                                                  \
324   {                                                                            \
325     auto &ICV = ICVs[Name];                                                    \
326     ICV.Getter = RTL;                                                          \
327   }
328 #define ICV_DATA_ENV(Enum, _Name, _EnvVarName, Init)                           \
329   {                                                                            \
330     auto &ICV = ICVs[Enum];                                                    \
331     ICV.Name = _Name;                                                          \
332     ICV.Kind = Enum;                                                           \
333     ICV.InitKind = Init;                                                       \
334     ICV.EnvVarName = _EnvVarName;                                              \
335     switch (ICV.InitKind) {                                                    \
336     case ICV_IMPLEMENTATION_DEFINED:                                           \
337       ICV.InitValue = nullptr;                                                 \
338       break;                                                                   \
339     case ICV_ZERO:                                                             \
340       ICV.InitValue = ConstantInt::get(                                        \
341           Type::getInt32Ty(OMPBuilder.Int32->getContext()), 0);                \
342       break;                                                                   \
343     case ICV_FALSE:                                                            \
344       ICV.InitValue = ConstantInt::getFalse(OMPBuilder.Int1->getContext());    \
345       break;                                                                   \
346     case ICV_LAST:                                                             \
347       break;                                                                   \
348     }                                                                          \
349   }
350 #include "llvm/Frontend/OpenMP/OMPKinds.def"
351   }
352 
353   /// Returns true if the function declaration \p F matches the runtime
354   /// function types, that is, return type \p RTFRetType, and argument types
355   /// \p RTFArgTypes.
356   static bool declMatchesRTFTypes(Function *F, Type *RTFRetType,
357                                   SmallVector<Type *, 8> &RTFArgTypes) {
358     // TODO: We should output information to the user (under debug output
359     //       and via remarks).
360 
361     if (!F)
362       return false;
363     if (F->getReturnType() != RTFRetType)
364       return false;
365     if (F->arg_size() != RTFArgTypes.size())
366       return false;
367 
368     auto RTFTyIt = RTFArgTypes.begin();
369     for (Argument &Arg : F->args()) {
370       if (Arg.getType() != *RTFTyIt)
371         return false;
372 
373       ++RTFTyIt;
374     }
375 
376     return true;
377   }
378 
379   // Helper to collect all uses of the declaration in the UsesMap.
380   unsigned collectUses(RuntimeFunctionInfo &RFI, bool CollectStats = true) {
381     unsigned NumUses = 0;
382     if (!RFI.Declaration)
383       return NumUses;
384     OMPBuilder.addAttributes(RFI.Kind, *RFI.Declaration);
385 
386     if (CollectStats) {
387       NumOpenMPRuntimeFunctionsIdentified += 1;
388       NumOpenMPRuntimeFunctionUsesIdentified += RFI.Declaration->getNumUses();
389     }
390 
391     // TODO: We directly convert uses into proper calls and unknown uses.
392     for (Use &U : RFI.Declaration->uses()) {
393       if (Instruction *UserI = dyn_cast<Instruction>(U.getUser())) {
394         if (ModuleSlice.count(UserI->getFunction())) {
395           RFI.getOrCreateUseVector(UserI->getFunction()).push_back(&U);
396           ++NumUses;
397         }
398       } else {
399         RFI.getOrCreateUseVector(nullptr).push_back(&U);
400         ++NumUses;
401       }
402     }
403     return NumUses;
404   }
405 
406   // Helper function to recollect uses of a runtime function.
407   void recollectUsesForFunction(RuntimeFunction RTF) {
408     auto &RFI = RFIs[RTF];
409     RFI.clearUsesMap();
410     collectUses(RFI, /*CollectStats*/ false);
411   }
412 
413   // Helper function to recollect uses of all runtime functions.
414   void recollectUses() {
415     for (int Idx = 0; Idx < RFIs.size(); ++Idx)
416       recollectUsesForFunction(static_cast<RuntimeFunction>(Idx));
417   }
418 
419   /// Helper to initialize all runtime function information for those defined
420   /// in OMPKinds.def.
421   void initializeRuntimeFunctions() {
422     Module &M = *((*ModuleSlice.begin())->getParent());
423 
424     // Helper macros for handling __VA_ARGS__ in OMP_RTL
425 #define OMP_TYPE(VarName, ...)                                                 \
426   Type *VarName = OMPBuilder.VarName;                                          \
427   (void)VarName;
428 
429 #define OMP_ARRAY_TYPE(VarName, ...)                                           \
430   ArrayType *VarName##Ty = OMPBuilder.VarName##Ty;                             \
431   (void)VarName##Ty;                                                           \
432   PointerType *VarName##PtrTy = OMPBuilder.VarName##PtrTy;                     \
433   (void)VarName##PtrTy;
434 
435 #define OMP_FUNCTION_TYPE(VarName, ...)                                        \
436   FunctionType *VarName = OMPBuilder.VarName;                                  \
437   (void)VarName;                                                               \
438   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
439   (void)VarName##Ptr;
440 
441 #define OMP_STRUCT_TYPE(VarName, ...)                                          \
442   StructType *VarName = OMPBuilder.VarName;                                    \
443   (void)VarName;                                                               \
444   PointerType *VarName##Ptr = OMPBuilder.VarName##Ptr;                         \
445   (void)VarName##Ptr;
446 
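// For each runtime function listed via OMP_RTL in OMPKinds.def, look up a
// declaration with the same name in the module, verify that its type matches,
// and record it together with all of its uses in the RFIs table.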
447 #define OMP_RTL(_Enum, _Name, _IsVarArg, _ReturnType, ...)                     \
448   {                                                                            \
449     SmallVector<Type *, 8> ArgsTypes({__VA_ARGS__});                           \
450     Function *F = M.getFunction(_Name);                                        \
451     RTLFunctions.insert(F);                                                    \
452     if (declMatchesRTFTypes(F, OMPBuilder._ReturnType, ArgsTypes)) {           \
453       RuntimeFunctionIDMap[F] = _Enum;                                         \
454       F->removeFnAttr(Attribute::NoInline);                                    \
455       auto &RFI = RFIs[_Enum];                                                 \
456       RFI.Kind = _Enum;                                                        \
457       RFI.Name = _Name;                                                        \
458       RFI.IsVarArg = _IsVarArg;                                                \
459       RFI.ReturnType = OMPBuilder._ReturnType;                                 \
460       RFI.ArgumentTypes = std::move(ArgsTypes);                                \
461       RFI.Declaration = F;                                                     \
462       unsigned NumUses = collectUses(RFI);                                     \
463       (void)NumUses;                                                           \
464       LLVM_DEBUG({                                                             \
465         dbgs() << TAG << RFI.Name << (RFI.Declaration ? "" : " not")           \
466                << " found\n";                                                  \
467         if (RFI.Declaration)                                                   \
468           dbgs() << TAG << "-> got " << NumUses << " uses in "                 \
469                  << RFI.getNumFunctionsWithUses()                              \
470                  << " different functions.\n";                                 \
471       });                                                                      \
472     }                                                                          \
473   }
474 #include "llvm/Frontend/OpenMP/OMPKinds.def"
475 
476     // TODO: We should attach the attributes defined in OMPKinds.def.
477   }
478 
479   /// Collection of known kernels (\see Kernel) in the module.
480   SmallPtrSetImpl<Kernel> &Kernels;
481 
482   /// Collection of known OpenMP runtime functions.
483   DenseSet<const Function *> RTLFunctions;
484 };
485 
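/// A BooleanState paired with a SetVector of elements. If \p InsertInvalidates
/// is true, inserting an element also moves the boolean state to the
/// pessimistic fixpoint; this is how "we reached something unknown" is modeled
/// by the kernel info states below.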
486 template <typename Ty, bool InsertInvalidates = true>
487 struct BooleanStateWithSetVector : public BooleanState {
488   bool contains(const Ty &Elem) const { return Set.contains(Elem); }
489   bool insert(const Ty &Elem) {
490     if (InsertInvalidates)
491       BooleanState::indicatePessimisticFixpoint();
492     return Set.insert(Elem);
493   }
494 
495   const Ty &operator[](int Idx) const { return Set[Idx]; }
496   bool operator==(const BooleanStateWithSetVector &RHS) const {
497     return BooleanState::operator==(RHS) && Set == RHS.Set;
498   }
499   bool operator!=(const BooleanStateWithSetVector &RHS) const {
500     return !(*this == RHS);
501   }
502 
503   bool empty() const { return Set.empty(); }
504   size_t size() const { return Set.size(); }
505 
506   /// "Clamp" this state with \p RHS.
507   BooleanStateWithSetVector &operator^=(const BooleanStateWithSetVector &RHS) {
508     BooleanState::operator^=(RHS);
509     Set.insert(RHS.Set.begin(), RHS.Set.end());
510     return *this;
511   }
512 
513 private:
514   /// A set to keep track of elements.
515   SetVector<Ty> Set;
516 
517 public:
518   typename decltype(Set)::iterator begin() { return Set.begin(); }
519   typename decltype(Set)::iterator end() { return Set.end(); }
520   typename decltype(Set)::const_iterator begin() const { return Set.begin(); }
521   typename decltype(Set)::const_iterator end() const { return Set.end(); }
522 };
523 
524 template <typename Ty, bool InsertInvalidates = true>
525 using BooleanStateWithPtrSetVector =
526     BooleanStateWithSetVector<Ty *, InsertInvalidates>;
527 
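/// State tracked for a kernel (and the functions transitively reachable from
/// it) during the Attributor fixpoint iteration: which parallel regions it may
/// reach, whether it is compatible with SPMD execution, and the
/// __kmpc_target_init/_deinit calls that delimit it.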
528 struct KernelInfoState : AbstractState {
529   /// Flag to track if we reached a fixpoint.
530   bool IsAtFixpoint = false;
531 
532   /// The parallel regions (identified by the outlined parallel functions) that
533   /// can be reached from the associated function.
534   BooleanStateWithPtrSetVector<Function, /* InsertInvalidates */ false>
535       ReachedKnownParallelRegions;
536 
537   /// State to track what parallel regions we might reach.
538   BooleanStateWithPtrSetVector<CallBase> ReachedUnknownParallelRegions;
539 
540   /// State to track if we are in SPMD-mode, assumed or known, and why we decided
541   /// we cannot be. If it is assumed, then RequiresFullRuntime should also be
542   /// false.
543   BooleanStateWithPtrSetVector<Instruction, false> SPMDCompatibilityTracker;
544 
545   /// The __kmpc_target_init call in this kernel, if any. If we find more than
546   /// one we abort as the kernel is malformed.
547   CallBase *KernelInitCB = nullptr;
548 
549   /// The __kmpc_target_deinit call in this kernel, if any. If we find more than
550   /// one we abort as the kernel is malformed.
551   CallBase *KernelDeinitCB = nullptr;
552 
553   /// Flag to indicate if the associated function is a kernel entry.
554   bool IsKernelEntry = false;
555 
556   /// State to track what kernel entries can reach the associated function.
557   BooleanStateWithPtrSetVector<Function, false> ReachingKernelEntries;
558 
559   /// State to indicate if we can track the parallel level of the associated
560   /// function. We will give up tracking if we encounter an unknown caller or
561   /// the caller is __kmpc_parallel_51.
562   BooleanStateWithSetVector<uint8_t> ParallelLevels;
563 
564   /// Abstract State interface
565   ///{
566 
567   KernelInfoState() {}
568   KernelInfoState(bool BestState) {
569     if (!BestState)
570       indicatePessimisticFixpoint();
571   }
572 
573   /// See AbstractState::isValidState(...)
574   bool isValidState() const override { return true; }
575 
576   /// See AbstractState::isAtFixpoint(...)
577   bool isAtFixpoint() const override { return IsAtFixpoint; }
578 
579   /// See AbstractState::indicatePessimisticFixpoint(...)
580   ChangeStatus indicatePessimisticFixpoint() override {
581     IsAtFixpoint = true;
582     SPMDCompatibilityTracker.indicatePessimisticFixpoint();
583     ReachedUnknownParallelRegions.indicatePessimisticFixpoint();
584     return ChangeStatus::CHANGED;
585   }
586 
587   /// See AbstractState::indicateOptimisticFixpoint(...)
588   ChangeStatus indicateOptimisticFixpoint() override {
589     IsAtFixpoint = true;
590     return ChangeStatus::UNCHANGED;
591   }
592 
593   /// Return the assumed state
594   KernelInfoState &getAssumed() { return *this; }
595   const KernelInfoState &getAssumed() const { return *this; }
596 
597   bool operator==(const KernelInfoState &RHS) const {
598     if (SPMDCompatibilityTracker != RHS.SPMDCompatibilityTracker)
599       return false;
600     if (ReachedKnownParallelRegions != RHS.ReachedKnownParallelRegions)
601       return false;
602     if (ReachedUnknownParallelRegions != RHS.ReachedUnknownParallelRegions)
603       return false;
604     if (ReachingKernelEntries != RHS.ReachingKernelEntries)
605       return false;
606     return true;
607   }
608 
609   /// Returns true if this kernel may contain OpenMP parallel regions.
610   bool mayContainParallelRegion() {
611     return !ReachedKnownParallelRegions.empty() ||
612            !ReachedUnknownParallelRegions.empty();
613   }
614 
615   /// Return the best possible (optimistic) state.
616   static KernelInfoState getBestState() { return KernelInfoState(true); }
617 
618   static KernelInfoState getBestState(KernelInfoState &KIS) {
619     return getBestState();
620   }
621 
622   /// Return the worst possible (pessimistic) state.
623   static KernelInfoState getWorstState() { return KernelInfoState(false); }
624 
625   /// "Clamp" this state with \p KIS.
626   KernelInfoState operator^=(const KernelInfoState &KIS) {
627     // Do not merge two different _init and _deinit call sites.
628     if (KIS.KernelInitCB) {
629       if (KernelInitCB && KernelInitCB != KIS.KernelInitCB)
630         indicatePessimisticFixpoint();
631       KernelInitCB = KIS.KernelInitCB;
632     }
633     if (KIS.KernelDeinitCB) {
634       if (KernelDeinitCB && KernelDeinitCB != KIS.KernelDeinitCB)
635         indicatePessimisticFixpoint();
636       KernelDeinitCB = KIS.KernelDeinitCB;
637     }
638     SPMDCompatibilityTracker ^= KIS.SPMDCompatibilityTracker;
639     ReachedKnownParallelRegions ^= KIS.ReachedKnownParallelRegions;
640     ReachedUnknownParallelRegions ^= KIS.ReachedUnknownParallelRegions;
641     return *this;
642   }
643 
644   KernelInfoState operator&=(const KernelInfoState &KIS) {
645     return (*this ^= KIS);
646   }
647 
648   ///}
649 };
650 
651 /// Used to map the values physically (in the IR) stored in an offload
652 /// array to a vector in memory.
653 struct OffloadArray {
654   /// Physical array (in the IR).
655   AllocaInst *Array = nullptr;
656   /// Mapped values.
657   SmallVector<Value *, 8> StoredValues;
658   /// Last stores made in the offload array.
659   SmallVector<StoreInst *, 8> LastAccesses;
660 
661   OffloadArray() = default;
662 
663   /// Initializes the OffloadArray with the values stored in \p Array before
664   /// instruction \p Before is reached. Returns false if the initialization
665   /// fails.
666   /// This MUST be used immediately after the construction of the object.
667   bool initialize(AllocaInst &Array, Instruction &Before) {
668     if (!Array.getAllocatedType()->isArrayTy())
669       return false;
670 
671     if (!getValues(Array, Before))
672       return false;
673 
674     this->Array = &Array;
675     return true;
676   }
677 
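  /// Argument positions of the interesting operands in a
  /// __tgt_target_data_begin_mapper call (device id and the base-pointers,
  /// pointers, and sizes arrays); used by getValuesInOffloadArrays below.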
678   static const unsigned DeviceIDArgNum = 1;
679   static const unsigned BasePtrsArgNum = 3;
680   static const unsigned PtrsArgNum = 4;
681   static const unsigned SizesArgNum = 5;
682 
683 private:
684   /// Traverses the BasicBlock where \p Array is, collecting the stores made to
685   /// \p Array, leaving StoredValues with the values stored before the
686   /// instruction \p Before is reached.
687   bool getValues(AllocaInst &Array, Instruction &Before) {
688     // Initialize container.
689     const uint64_t NumValues = Array.getAllocatedType()->getArrayNumElements();
690     StoredValues.assign(NumValues, nullptr);
691     LastAccesses.assign(NumValues, nullptr);
692 
693     // TODO: This assumes the instruction \p Before is in the same
694     //  BasicBlock as Array. Make it general, for any control flow graph.
695     BasicBlock *BB = Array.getParent();
696     if (BB != Before.getParent())
697       return false;
698 
699     const DataLayout &DL = Array.getModule()->getDataLayout();
700     const unsigned int PointerSize = DL.getPointerSize();
701 
702     for (Instruction &I : *BB) {
703       if (&I == &Before)
704         break;
705 
706       if (!isa<StoreInst>(&I))
707         continue;
708 
709       auto *S = cast<StoreInst>(&I);
710       int64_t Offset = -1;
711       auto *Dst =
712           GetPointerBaseWithConstantOffset(S->getPointerOperand(), Offset, DL);
713       if (Dst == &Array) {
714         int64_t Idx = Offset / PointerSize;
715         StoredValues[Idx] = getUnderlyingObject(S->getValueOperand());
716         LastAccesses[Idx] = S;
717       }
718     }
719 
720     return isFilled();
721   }
722 
723   /// Returns true if all entries in StoredValues and
724   /// LastAccesses are non-null.
725   bool isFilled() {
726     const unsigned NumValues = StoredValues.size();
727     for (unsigned I = 0; I < NumValues; ++I) {
728       if (!StoredValues[I] || !LastAccesses[I])
729         return false;
730     }
731 
732     return true;
733   }
734 };
735 
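/// Driver for the OpenMP-specific optimizations in this file. One instance is
/// created per SCC (or for the whole module in the module-pass case) and
/// run(...) applies the transformations listed in the file header.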
736 struct OpenMPOpt {
737 
738   using OptimizationRemarkGetter =
739       function_ref<OptimizationRemarkEmitter &(Function *)>;
740 
741   OpenMPOpt(SmallVectorImpl<Function *> &SCC, CallGraphUpdater &CGUpdater,
742             OptimizationRemarkGetter OREGetter,
743             OMPInformationCache &OMPInfoCache, Attributor &A)
744       : M(*(*SCC.begin())->getParent()), SCC(SCC), CGUpdater(CGUpdater),
745         OREGetter(OREGetter), OMPInfoCache(OMPInfoCache), A(A) {}
746 
747   /// Check if any remarks are enabled for openmp-opt
748   bool remarksEnabled() {
749     auto &Ctx = M.getContext();
750     return Ctx.getDiagHandlerPtr()->isAnyRemarkEnabled(DEBUG_TYPE);
751   }
752 
753   /// Run all OpenMP optimizations on the underlying SCC/ModuleSlice.
754   bool run(bool IsModulePass) {
755     if (SCC.empty())
756       return false;
757 
758     bool Changed = false;
759 
760     LLVM_DEBUG(dbgs() << TAG << "Run on SCC with " << SCC.size()
761                       << " functions in a slice with "
762                       << OMPInfoCache.ModuleSlice.size() << " functions\n");
763 
764     if (IsModulePass) {
765       Changed |= runAttributor(IsModulePass);
766 
767       // Recollect uses, in case Attributor deleted any.
768       OMPInfoCache.recollectUses();
769 
770       // TODO: This should be folded into buildCustomStateMachine.
771       Changed |= rewriteDeviceCodeStateMachine();
772 
773       if (remarksEnabled())
774         analysisGlobalization();
775     } else {
776       if (PrintICVValues)
777         printICVs();
778       if (PrintOpenMPKernels)
779         printKernels();
780 
781       Changed |= runAttributor(IsModulePass);
782 
783       // Recollect uses, in case Attributor deleted any.
784       OMPInfoCache.recollectUses();
785 
786       Changed |= deleteParallelRegions();
787 
788       if (HideMemoryTransferLatency)
789         Changed |= hideMemTransfersLatency();
790       Changed |= deduplicateRuntimeCalls();
791       if (EnableParallelRegionMerging) {
792         if (mergeParallelRegions()) {
793           deduplicateRuntimeCalls();
794           Changed = true;
795         }
796       }
797     }
798 
799     return Changed;
800   }
801 
802   /// Print initial ICV values for testing.
803   /// FIXME: This should be done from the Attributor once it is added.
804   void printICVs() const {
805     InternalControlVar ICVs[] = {ICV_nthreads, ICV_active_levels, ICV_cancel,
806                                  ICV_proc_bind};
807 
808     for (Function *F : OMPInfoCache.ModuleSlice) {
809       for (auto ICV : ICVs) {
810         auto ICVInfo = OMPInfoCache.ICVs[ICV];
811         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
812           return ORA << "OpenMP ICV " << ore::NV("OpenMPICV", ICVInfo.Name)
813                      << " Value: "
814                      << (ICVInfo.InitValue
815                              ? toString(ICVInfo.InitValue->getValue(), 10, true)
816                              : "IMPLEMENTATION_DEFINED");
817         };
818 
819         emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPICVTracker", Remark);
820       }
821     }
822   }
823 
824   /// Print OpenMP GPU kernels for testing.
825   void printKernels() const {
826     for (Function *F : SCC) {
827       if (!OMPInfoCache.Kernels.count(F))
828         continue;
829 
830       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
831         return ORA << "OpenMP GPU kernel "
832                    << ore::NV("OpenMPGPUKernel", F->getName()) << "\n";
833       };
834 
835       emitRemark<OptimizationRemarkAnalysis>(F, "OpenMPGPU", Remark);
836     }
837   }
838 
839   /// Return the call if \p U is a callee use in a regular call. If \p RFI is
840   /// given, it has to be the callee, otherwise nullptr is returned.
841   static CallInst *getCallIfRegularCall(
842       Use &U, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
843     CallInst *CI = dyn_cast<CallInst>(U.getUser());
844     if (CI && CI->isCallee(&U) && !CI->hasOperandBundles() &&
845         (!RFI ||
846          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
847       return CI;
848     return nullptr;
849   }
850 
851   /// Return the call if \p V is a regular call. If \p RFI is given, it has to
852   /// be the callee, otherwise nullptr is returned.
853   static CallInst *getCallIfRegularCall(
854       Value &V, OMPInformationCache::RuntimeFunctionInfo *RFI = nullptr) {
855     CallInst *CI = dyn_cast<CallInst>(&V);
856     if (CI && !CI->hasOperandBundles() &&
857         (!RFI ||
858          (RFI->Declaration && CI->getCalledFunction() == RFI->Declaration)))
859       return CI;
860     return nullptr;
861   }
862 
863 private:
864   /// Merge parallel regions when it is safe.
865   bool mergeParallelRegions() {
866     const unsigned CallbackCalleeOperand = 2;
867     const unsigned CallbackFirstArgOperand = 3;
868     using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
869 
870     // Check if there are any __kmpc_fork_call calls to merge.
871     OMPInformationCache::RuntimeFunctionInfo &RFI =
872         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
873 
874     if (!RFI.Declaration)
875       return false;
876 
877     // Unmergable calls that prevent merging a parallel region.
878     OMPInformationCache::RuntimeFunctionInfo UnmergableCallsInfo[] = {
879         OMPInfoCache.RFIs[OMPRTL___kmpc_push_proc_bind],
880         OMPInfoCache.RFIs[OMPRTL___kmpc_push_num_threads],
881     };
882 
883     bool Changed = false;
884     LoopInfo *LI = nullptr;
885     DominatorTree *DT = nullptr;
886 
887     SmallDenseMap<BasicBlock *, SmallPtrSet<Instruction *, 4>> BB2PRMap;
888 
889     BasicBlock *StartBB = nullptr, *EndBB = nullptr;
890     auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
891                          BasicBlock &ContinuationIP) {
892       BasicBlock *CGStartBB = CodeGenIP.getBlock();
893       BasicBlock *CGEndBB =
894           SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
895       assert(StartBB != nullptr && "StartBB should not be null");
896       CGStartBB->getTerminator()->setSuccessor(0, StartBB);
897       assert(EndBB != nullptr && "EndBB should not be null");
898       EndBB->getTerminator()->setSuccessor(0, CGEndBB);
899     };
900 
901     auto PrivCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP, Value &,
902                       Value &Inner, Value *&ReplacementValue) -> InsertPointTy {
903       ReplacementValue = &Inner;
904       return CodeGenIP;
905     };
906 
907     auto FiniCB = [&](InsertPointTy CodeGenIP) {};
908 
909     /// Create a sequential execution region within a merged parallel region,
910     /// encapsulated in a master construct with a barrier for synchronization.
911     auto CreateSequentialRegion = [&](Function *OuterFn,
912                                       BasicBlock *OuterPredBB,
913                                       Instruction *SeqStartI,
914                                       Instruction *SeqEndI) {
915       // Isolate the instructions of the sequential region to a separate
916       // block.
917       BasicBlock *ParentBB = SeqStartI->getParent();
918       BasicBlock *SeqEndBB =
919           SplitBlock(ParentBB, SeqEndI->getNextNode(), DT, LI);
920       BasicBlock *SeqAfterBB =
921           SplitBlock(SeqEndBB, &*SeqEndBB->getFirstInsertionPt(), DT, LI);
922       BasicBlock *SeqStartBB =
923           SplitBlock(ParentBB, SeqStartI, DT, LI, nullptr, "seq.par.merged");
924 
925       assert(ParentBB->getUniqueSuccessor() == SeqStartBB &&
926              "Expected a different CFG");
927       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
928       ParentBB->getTerminator()->eraseFromParent();
929 
930       auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
931                            BasicBlock &ContinuationIP) {
932         BasicBlock *CGStartBB = CodeGenIP.getBlock();
933         BasicBlock *CGEndBB =
934             SplitBlock(CGStartBB, &*CodeGenIP.getPoint(), DT, LI);
935         assert(SeqStartBB != nullptr && "SeqStartBB should not be null");
936         CGStartBB->getTerminator()->setSuccessor(0, SeqStartBB);
937         assert(SeqEndBB != nullptr && "SeqEndBB should not be null");
938         SeqEndBB->getTerminator()->setSuccessor(0, CGEndBB);
939       };
940       auto FiniCB = [&](InsertPointTy CodeGenIP) {};
941 
942       // Find outputs from the sequential region to outside users and
943       // broadcast their values to them.
944       for (Instruction &I : *SeqStartBB) {
945         SmallPtrSet<Instruction *, 4> OutsideUsers;
946         for (User *Usr : I.users()) {
947           Instruction &UsrI = *cast<Instruction>(Usr);
948           // Ignore outputs to lifetime intrinsics; code extraction for the
949           // merged parallel region will fix them.
950           if (UsrI.isLifetimeStartOrEnd())
951             continue;
952 
953           if (UsrI.getParent() != SeqStartBB)
954             OutsideUsers.insert(&UsrI);
955         }
956 
957         if (OutsideUsers.empty())
958           continue;
959 
960         // Emit an alloca in the outer region to store the broadcasted
961         // value.
962         const DataLayout &DL = M.getDataLayout();
963         AllocaInst *AllocaI = new AllocaInst(
964             I.getType(), DL.getAllocaAddrSpace(), nullptr,
965             I.getName() + ".seq.output.alloc", &OuterFn->front().front());
966 
967         // Emit a store instruction in the sequential BB to update the
968         // value.
969         new StoreInst(&I, AllocaI, SeqStartBB->getTerminator());
970 
971         // Emit a load instruction and replace the use of the output value
972         // with it.
973         for (Instruction *UsrI : OutsideUsers) {
974           LoadInst *LoadI = new LoadInst(
975               I.getType(), AllocaI, I.getName() + ".seq.output.load", UsrI);
976           UsrI->replaceUsesOfWith(&I, LoadI);
977         }
978       }
979 
980       OpenMPIRBuilder::LocationDescription Loc(
981           InsertPointTy(ParentBB, ParentBB->end()), DL);
982       InsertPointTy SeqAfterIP =
983           OMPInfoCache.OMPBuilder.createMaster(Loc, BodyGenCB, FiniCB);
984 
985       OMPInfoCache.OMPBuilder.createBarrier(SeqAfterIP, OMPD_parallel);
986 
987       BranchInst::Create(SeqAfterBB, SeqAfterIP.getBlock());
988 
989       LLVM_DEBUG(dbgs() << TAG << "After sequential inlining " << *OuterFn
990                         << "\n");
991     };
992 
993     // Helper to merge the __kmpc_fork_call calls in MergableCIs. They are all
994     // contained in BB and only separated by instructions that can be
995     // redundantly executed in parallel. The block BB is split before the first
996     // call (in MergableCIs) and after the last so the entire region we merge
997     // into a single parallel region is contained in a single basic block
998     // without any other instructions. We use the OpenMPIRBuilder to outline
999     // that block and call the resulting function via __kmpc_fork_call.
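    //
    // Roughly, for two mergable calls the merged region looks like this
    // (illustrative sketch, block names follow the code below):
    //   omp.par.merged:              ; outlined and invoked by one fork_call
    //     call @callback0(...)       ; former parallel region 0
    //     barrier
    //     <in-between code, wrapped in a master region>
    //     call @callback1(...)       ; former parallel region 1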
1000     auto Merge = [&](SmallVectorImpl<CallInst *> &MergableCIs, BasicBlock *BB) {
1001       // TODO: Change the interface to allow single CIs expanded, e.g., to
1002       // include an outer loop.
1003       assert(MergableCIs.size() > 1 && "Assumed multiple mergable CIs");
1004 
1005       auto Remark = [&](OptimizationRemark OR) {
1006         OR << "Parallel region merged with parallel region"
1007            << (MergableCIs.size() > 2 ? "s" : "") << " at ";
1008         for (auto *CI : llvm::drop_begin(MergableCIs)) {
1009           OR << ore::NV("OpenMPParallelMerge", CI->getDebugLoc());
1010           if (CI != MergableCIs.back())
1011             OR << ", ";
1012         }
1013         return OR << ".";
1014       };
1015 
1016       emitRemark<OptimizationRemark>(MergableCIs.front(), "OMP150", Remark);
1017 
1018       Function *OriginalFn = BB->getParent();
1019       LLVM_DEBUG(dbgs() << TAG << "Merge " << MergableCIs.size()
1020                         << " parallel regions in " << OriginalFn->getName()
1021                         << "\n");
1022 
1023       // Isolate the calls to merge in a separate block.
1024       EndBB = SplitBlock(BB, MergableCIs.back()->getNextNode(), DT, LI);
1025       BasicBlock *AfterBB =
1026           SplitBlock(EndBB, &*EndBB->getFirstInsertionPt(), DT, LI);
1027       StartBB = SplitBlock(BB, MergableCIs.front(), DT, LI, nullptr,
1028                            "omp.par.merged");
1029 
1030       assert(BB->getUniqueSuccessor() == StartBB && "Expected a different CFG");
1031       const DebugLoc DL = BB->getTerminator()->getDebugLoc();
1032       BB->getTerminator()->eraseFromParent();
1033 
1034       // Create sequential regions for sequential instructions that are
1035       // in-between mergable parallel regions.
1036       for (auto *It = MergableCIs.begin(), *End = MergableCIs.end() - 1;
1037            It != End; ++It) {
1038         Instruction *ForkCI = *It;
1039         Instruction *NextForkCI = *(It + 1);
1040 
1041         // Continue if there are no in-between instructions.
1042         if (ForkCI->getNextNode() == NextForkCI)
1043           continue;
1044 
1045         CreateSequentialRegion(OriginalFn, BB, ForkCI->getNextNode(),
1046                                NextForkCI->getPrevNode());
1047       }
1048 
1049       OpenMPIRBuilder::LocationDescription Loc(InsertPointTy(BB, BB->end()),
1050                                                DL);
1051       IRBuilder<>::InsertPoint AllocaIP(
1052           &OriginalFn->getEntryBlock(),
1053           OriginalFn->getEntryBlock().getFirstInsertionPt());
1054       // Create the merged parallel region with default proc binding, to
1055       // avoid overriding binding settings, and without explicit cancellation.
1056       InsertPointTy AfterIP = OMPInfoCache.OMPBuilder.createParallel(
1057           Loc, AllocaIP, BodyGenCB, PrivCB, FiniCB, nullptr, nullptr,
1058           OMP_PROC_BIND_default, /* IsCancellable */ false);
1059       BranchInst::Create(AfterBB, AfterIP.getBlock());
1060 
1061       // Perform the actual outlining.
1062       OMPInfoCache.OMPBuilder.finalize(OriginalFn,
1063                                        /* AllowExtractorSinking */ true);
1064 
1065       Function *OutlinedFn = MergableCIs.front()->getCaller();
1066 
1067       // Replace the __kmpc_fork_call calls with direct calls to the outlined
1068       // callbacks.
1069       SmallVector<Value *, 8> Args;
1070       for (auto *CI : MergableCIs) {
1071         Value *Callee =
1072             CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts();
1073         FunctionType *FT =
1074             cast<FunctionType>(Callee->getType()->getPointerElementType());
1075         Args.clear();
1076         Args.push_back(OutlinedFn->getArg(0));
1077         Args.push_back(OutlinedFn->getArg(1));
1078         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1079              U < E; ++U)
1080           Args.push_back(CI->getArgOperand(U));
1081 
1082         CallInst *NewCI = CallInst::Create(FT, Callee, Args, "", CI);
1083         if (CI->getDebugLoc())
1084           NewCI->setDebugLoc(CI->getDebugLoc());
1085 
1086         // Forward parameter attributes from the callback to the callee.
1087         for (unsigned U = CallbackFirstArgOperand, E = CI->getNumArgOperands();
1088              U < E; ++U)
1089           for (const Attribute &A : CI->getAttributes().getParamAttrs(U))
1090             NewCI->addParamAttr(
1091                 U - (CallbackFirstArgOperand - CallbackCalleeOperand), A);
1092 
1093         // Emit an explicit barrier to replace the implicit fork-join barrier.
1094         if (CI != MergableCIs.back()) {
1095           // TODO: Remove barrier if the merged parallel region includes the
1096           // 'nowait' clause.
1097           OMPInfoCache.OMPBuilder.createBarrier(
1098               InsertPointTy(NewCI->getParent(),
1099                             NewCI->getNextNode()->getIterator()),
1100               OMPD_parallel);
1101         }
1102 
1103         CI->eraseFromParent();
1104       }
1105 
1106       assert(OutlinedFn != OriginalFn && "Outlining failed");
1107       CGUpdater.registerOutlinedFunction(*OriginalFn, *OutlinedFn);
1108       CGUpdater.reanalyzeFunction(*OriginalFn);
1109 
1110       NumOpenMPParallelRegionsMerged += MergableCIs.size();
1111 
1112       return true;
1113     };
1114 
1115     // Helper function that identifies sequences of
1116     // __kmpc_fork_call uses in a basic block.
1117     auto DetectPRsCB = [&](Use &U, Function &F) {
1118       CallInst *CI = getCallIfRegularCall(U, &RFI);
1119       BB2PRMap[CI->getParent()].insert(CI);
1120 
1121       return false;
1122     };
1123 
1124     BB2PRMap.clear();
1125     RFI.foreachUse(SCC, DetectPRsCB);
1126     SmallVector<SmallVector<CallInst *, 4>, 4> MergableCIsVector;
1127     // Find mergable parallel regions within a basic block that are
1128     // safe to merge, that is, any in-between instructions can safely
1129     // execute in parallel after merging.
1130     // TODO: support merging across basic-blocks.
1131     for (auto &It : BB2PRMap) {
1132       auto &CIs = It.getSecond();
1133       if (CIs.size() < 2)
1134         continue;
1135 
1136       BasicBlock *BB = It.getFirst();
1137       SmallVector<CallInst *, 4> MergableCIs;
1138 
1139       /// Returns true if the instruction is mergable, false otherwise.
1140       /// A terminator instruction is unmergable by definition since merging
1141       /// works within a BB. Instructions before the mergable region are
1142       /// mergable if they are not calls to OpenMP runtime functions that may
1143       /// set different execution parameters for subsequent parallel regions.
1144       /// Instructions in-between parallel regions are mergable if they are not
1145       /// calls to any non-intrinsic function since that may call a non-mergable
1146       /// OpenMP runtime function.
1147       auto IsMergable = [&](Instruction &I, bool IsBeforeMergableRegion) {
1148         // We do not merge across BBs, hence return false (unmergable) if the
1149         // instruction is a terminator.
1150         if (I.isTerminator())
1151           return false;
1152 
1153         if (!isa<CallInst>(&I))
1154           return true;
1155 
1156         CallInst *CI = cast<CallInst>(&I);
1157         if (IsBeforeMergableRegion) {
1158           Function *CalledFunction = CI->getCalledFunction();
1159           if (!CalledFunction)
1160             return false;
1161           // Return false (unmergable) if the call before the parallel
1162           // region calls an explicit affinity (proc_bind) or number of
1163           // threads (num_threads) compiler-generated function. Those settings
1164           // may be incompatible with following parallel regions.
1165           // TODO: ICV tracking to detect compatibility.
1166           for (const auto &RFI : UnmergableCallsInfo) {
1167             if (CalledFunction == RFI.Declaration)
1168               return false;
1169           }
1170         } else {
1171           // Return false (unmergable) if there is a call instruction
1172           // in-between parallel regions when it is not an intrinsic. It
1173           // may call an unmergable OpenMP runtime function in its callpath.
1174           // TODO: Keep track of possible OpenMP calls in the callpath.
1175           if (!isa<IntrinsicInst>(CI))
1176             return false;
1177         }
1178 
1179         return true;
1180       };
1181       // Find maximal number of parallel region CIs that are safe to merge.
1182       for (auto It = BB->begin(), End = BB->end(); It != End;) {
1183         Instruction &I = *It;
1184         ++It;
1185 
1186         if (CIs.count(&I)) {
1187           MergableCIs.push_back(cast<CallInst>(&I));
1188           continue;
1189         }
1190 
1191         // Continue expanding if the instruction is mergable.
1192         if (IsMergable(I, MergableCIs.empty()))
1193           continue;
1194 
1195         // Forward the instruction iterator to skip the next parallel region
1196         // since there is an unmergable instruction which can affect it.
1197         for (; It != End; ++It) {
1198           Instruction &SkipI = *It;
1199           if (CIs.count(&SkipI)) {
1200             LLVM_DEBUG(dbgs() << TAG << "Skip parallel region " << SkipI
1201                               << " due to " << I << "\n");
1202             ++It;
1203             break;
1204           }
1205         }
1206 
1207         // Store mergable regions found.
1208         if (MergableCIs.size() > 1) {
1209           MergableCIsVector.push_back(MergableCIs);
1210           LLVM_DEBUG(dbgs() << TAG << "Found " << MergableCIs.size()
1211                             << " parallel regions in block " << BB->getName()
1212                             << " of function " << BB->getParent()->getName()
1213                             << "\n";);
1214         }
1215 
1216         MergableCIs.clear();
1217       }
1218 
1219       if (!MergableCIsVector.empty()) {
1220         Changed = true;
1221 
1222         for (auto &MergableCIs : MergableCIsVector)
1223           Merge(MergableCIs, BB);
1224         MergableCIsVector.clear();
1225       }
1226     }
1227 
1228     if (Changed) {
1229       /// Re-collect uses for fork calls, emitted barrier calls, and
1230       /// any emitted master/end_master calls.
1231       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_fork_call);
1232       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_barrier);
1233       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_master);
1234       OMPInfoCache.recollectUsesForFunction(OMPRTL___kmpc_end_master);
1235     }
1236 
1237     return Changed;
1238   }
1239 
1240   /// Try to delete parallel regions if possible.
1241   bool deleteParallelRegions() {
1242     const unsigned CallbackCalleeOperand = 2;
1243 
1244     OMPInformationCache::RuntimeFunctionInfo &RFI =
1245         OMPInfoCache.RFIs[OMPRTL___kmpc_fork_call];
1246 
1247     if (!RFI.Declaration)
1248       return false;
1249 
1250     bool Changed = false;
1251     auto DeleteCallCB = [&](Use &U, Function &) {
1252       CallInst *CI = getCallIfRegularCall(U);
1253       if (!CI)
1254         return false;
1255       auto *Fn = dyn_cast<Function>(
1256           CI->getArgOperand(CallbackCalleeOperand)->stripPointerCasts());
1257       if (!Fn)
1258         return false;
1259       if (!Fn->onlyReadsMemory())
1260         return false;
1261       if (!Fn->hasFnAttribute(Attribute::WillReturn))
1262         return false;
1263 
1264       LLVM_DEBUG(dbgs() << TAG << "Delete read-only parallel region in "
1265                         << CI->getCaller()->getName() << "\n");
1266 
1267       auto Remark = [&](OptimizationRemark OR) {
1268         return OR << "Removing parallel region with no side-effects.";
1269       };
1270       emitRemark<OptimizationRemark>(CI, "OMP160", Remark);
1271 
1272       CGUpdater.removeCallSite(*CI);
1273       CI->eraseFromParent();
1274       Changed = true;
1275       ++NumOpenMPParallelRegionsDeleted;
1276       return true;
1277     };
1278 
1279     RFI.foreachUse(SCC, DeleteCallCB);
1280 
1281     return Changed;
1282   }
1283 
1284   /// Try to eliminate runtime calls by reusing existing ones.
1285   bool deduplicateRuntimeCalls() {
1286     bool Changed = false;
1287 
1288     RuntimeFunction DeduplicableRuntimeCallIDs[] = {
1289         OMPRTL_omp_get_num_threads,
1290         OMPRTL_omp_in_parallel,
1291         OMPRTL_omp_get_cancellation,
1292         OMPRTL_omp_get_thread_limit,
1293         OMPRTL_omp_get_supported_active_levels,
1294         OMPRTL_omp_get_level,
1295         OMPRTL_omp_get_ancestor_thread_num,
1296         OMPRTL_omp_get_team_size,
1297         OMPRTL_omp_get_active_level,
1298         OMPRTL_omp_in_final,
1299         OMPRTL_omp_get_proc_bind,
1300         OMPRTL_omp_get_num_places,
1301         OMPRTL_omp_get_num_procs,
1302         OMPRTL_omp_get_place_num,
1303         OMPRTL_omp_get_partition_num_places,
1304         OMPRTL_omp_get_partition_place_nums};
1305 
1306     // Global-tid is handled separately.
1307     SmallSetVector<Value *, 16> GTIdArgs;
1308     collectGlobalThreadIdArguments(GTIdArgs);
1309     LLVM_DEBUG(dbgs() << TAG << "Found " << GTIdArgs.size()
1310                       << " global thread ID arguments\n");
1311 
1312     for (Function *F : SCC) {
1313       for (auto DeduplicableRuntimeCallID : DeduplicableRuntimeCallIDs)
1314         Changed |= deduplicateRuntimeCalls(
1315             *F, OMPInfoCache.RFIs[DeduplicableRuntimeCallID]);
1316 
1317       // __kmpc_global_thread_num is special as we can replace it with an
1318       // argument in enough cases to make it worth trying.
1319       Value *GTIdArg = nullptr;
1320       for (Argument &Arg : F->args())
1321         if (GTIdArgs.count(&Arg)) {
1322           GTIdArg = &Arg;
1323           break;
1324         }
1325       Changed |= deduplicateRuntimeCalls(
1326           *F, OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num], GTIdArg);
1327     }
1328 
1329     return Changed;
1330   }
1331 
1332   /// Tries to hide the latency of runtime calls that involve host to
1333   /// device memory transfers by splitting them into their "issue" and "wait"
1334   /// versions. The "issue" is moved upwards as much as possible. The "wait" is
1335   /// moved downwards as much as possible. The "issue" starts the memory
1336   /// transfer asynchronously, returning a handle. The "wait" waits on the
1337   /// returned handle for the memory transfer to finish.
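  ///
  /// Illustrative sketch (runtime signatures simplified):
  ///   call void @__tgt_target_data_begin_mapper(...)
  /// becomes
  ///   %handle = alloca %struct.__tgt_async_info
  ///   call void @__tgt_target_data_begin_mapper_issue(..., %handle)
  ///   ; ... code the transfer does not depend on ...
  ///   call void @__tgt_target_data_begin_mapper_wait(%device_id, %handle)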
1338   bool hideMemTransfersLatency() {
1339     auto &RFI = OMPInfoCache.RFIs[OMPRTL___tgt_target_data_begin_mapper];
1340     bool Changed = false;
1341     auto SplitMemTransfers = [&](Use &U, Function &Decl) {
1342       auto *RTCall = getCallIfRegularCall(U, &RFI);
1343       if (!RTCall)
1344         return false;
1345 
1346       OffloadArray OffloadArrays[3];
1347       if (!getValuesInOffloadArrays(*RTCall, OffloadArrays))
1348         return false;
1349 
1350       LLVM_DEBUG(dumpValuesInOffloadArrays(OffloadArrays));
1351 
1352       // TODO: Check if it can be moved upwards.
1353       bool WasSplit = false;
1354       Instruction *WaitMovementPoint = canBeMovedDownwards(*RTCall);
1355       if (WaitMovementPoint)
1356         WasSplit = splitTargetDataBeginRTC(*RTCall, *WaitMovementPoint);
1357 
1358       Changed |= WasSplit;
1359       return WasSplit;
1360     };
1361     RFI.foreachUse(SCC, SplitMemTransfers);
1362 
1363     return Changed;
1364   }
1365 
1366   void analysisGlobalization() {
1367     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
1368 
1369     auto CheckGlobalization = [&](Use &U, Function &Decl) {
1370       if (CallInst *CI = getCallIfRegularCall(U, &RFI)) {
1371         auto Remark = [&](OptimizationRemarkMissed ORM) {
1372           return ORM
1373                  << "Found thread data sharing on the GPU. "
1374                  << "Expect degraded performance due to data globalization.";
1375         };
1376         emitRemark<OptimizationRemarkMissed>(CI, "OMP112", Remark);
1377       }
1378 
1379       return false;
1380     };
1381 
1382     RFI.foreachUse(SCC, CheckGlobalization);
1383   }
1384 
1385   /// Maps the values stored in the offload arrays passed as arguments to
1386   /// \p RuntimeCall into the offload arrays in \p OAs.
1387   bool getValuesInOffloadArrays(CallInst &RuntimeCall,
1388                                 MutableArrayRef<OffloadArray> OAs) {
1389     assert(OAs.size() == 3 && "Need space for three offload arrays!");
1390 
1391     // A runtime call that involves memory offloading looks something like:
1392     // call void @__tgt_target_data_begin_mapper(arg0, arg1,
1393     //   i8** %offload_baseptrs, i8** %offload_ptrs, i64* %offload_sizes,
1394     // ...)
1395     // So, the idea is to access the allocas that allocate space for these
1396     // offload arrays, offload_baseptrs, offload_ptrs, offload_sizes.
1397     // Therefore:
1398     // i8** %offload_baseptrs.
1399     Value *BasePtrsArg =
1400         RuntimeCall.getArgOperand(OffloadArray::BasePtrsArgNum);
1401     // i8** %offload_ptrs.
1402     Value *PtrsArg = RuntimeCall.getArgOperand(OffloadArray::PtrsArgNum);
1403     // i8** %offload_sizes.
1404     Value *SizesArg = RuntimeCall.getArgOperand(OffloadArray::SizesArgNum);
1405 
1406     // Get values stored in **offload_baseptrs.
1407     auto *V = getUnderlyingObject(BasePtrsArg);
1408     if (!isa<AllocaInst>(V))
1409       return false;
1410     auto *BasePtrsArray = cast<AllocaInst>(V);
1411     if (!OAs[0].initialize(*BasePtrsArray, RuntimeCall))
1412       return false;
1413 
1414     // Get values stored in **offload_ptrs.
1415     V = getUnderlyingObject(PtrsArg);
1416     if (!isa<AllocaInst>(V))
1417       return false;
1418     auto *PtrsArray = cast<AllocaInst>(V);
1419     if (!OAs[1].initialize(*PtrsArray, RuntimeCall))
1420       return false;
1421 
1422     // Get values stored in **offload_sizes.
1423     V = getUnderlyingObject(SizesArg);
1424     // If it's a [constant] global array don't analyze it.
1425     if (isa<GlobalValue>(V))
1426       return isa<Constant>(V);
1427     if (!isa<AllocaInst>(V))
1428       return false;
1429 
1430     auto *SizesArray = cast<AllocaInst>(V);
1431     if (!OAs[2].initialize(*SizesArray, RuntimeCall))
1432       return false;
1433 
1434     return true;
1435   }
1436 
1437   /// Prints the values in the OffloadArrays \p OAs using LLVM_DEBUG.
1438   /// For now this is a way to test that the function getValuesInOffloadArrays
1439   /// is working properly.
1440   /// TODO: Move this to a unittest when unittests are available for OpenMPOpt.
1441   void dumpValuesInOffloadArrays(ArrayRef<OffloadArray> OAs) {
1442     assert(OAs.size() == 3 && "There are three offload arrays to debug!");
1443 
1444     LLVM_DEBUG(dbgs() << TAG << " Successfully got offload values:\n");
1445     std::string ValuesStr;
1446     raw_string_ostream Printer(ValuesStr);
1447     std::string Separator = " --- ";
1448 
1449     for (auto *BP : OAs[0].StoredValues) {
1450       BP->print(Printer);
1451       Printer << Separator;
1452     }
1453     LLVM_DEBUG(dbgs() << "\t\toffload_baseptrs: " << Printer.str() << "\n");
1454     ValuesStr.clear();
1455 
1456     for (auto *P : OAs[1].StoredValues) {
1457       P->print(Printer);
1458       Printer << Separator;
1459     }
1460     LLVM_DEBUG(dbgs() << "\t\toffload_ptrs: " << Printer.str() << "\n");
1461     ValuesStr.clear();
1462 
1463     for (auto *S : OAs[2].StoredValues) {
1464       S->print(Printer);
1465       Printer << Separator;
1466     }
1467     LLVM_DEBUG(dbgs() << "\t\toffload_sizes: " << Printer.str() << "\n");
1468   }
1469 
1470   /// Returns the instruction where the "wait" counterpart of \p RuntimeCall
1471   /// can be moved to, or nullptr if the move is not possible or not worth it.
1472   Instruction *canBeMovedDownwards(CallInst &RuntimeCall) {
1473     // FIXME: This traverses only the BasicBlock where RuntimeCall is.
1474     //  Make it traverse the CFG.
1475 
1476     Instruction *CurrentI = &RuntimeCall;
1477     bool IsWorthIt = false;
1478     while ((CurrentI = CurrentI->getNextNode())) {
1479 
1480       // TODO: Once we detect the regions to be offloaded we should use the
1481       //  alias analysis manager to check if CurrentI may modify one of
1482       //  the offloaded regions.
1483       if (CurrentI->mayHaveSideEffects() || CurrentI->mayReadFromMemory()) {
1484         if (IsWorthIt)
1485           return CurrentI;
1486 
1487         return nullptr;
1488       }
1489 
1490       // FIXME: For now, moving it over anything without side effects is
1491       //  considered worth it.
1492       IsWorthIt = true;
1493     }
1494 
1495     // Return end of BasicBlock.
1496     return RuntimeCall.getParent()->getTerminator();
1497   }
1498 
1499   /// Splits \p RuntimeCall into its "issue" and "wait" counterparts.
1500   bool splitTargetDataBeginRTC(CallInst &RuntimeCall,
1501                                Instruction &WaitMovementPoint) {
1502     // Create a stack-allocated handle (__tgt_async_info) at the beginning of
1503     // the function. It stores information about the async transfer, allowing
1504     // us to wait on it later.
1505     auto &IRBuilder = OMPInfoCache.OMPBuilder;
1506     auto *F = RuntimeCall.getCaller();
1507     Instruction *FirstInst = &(F->getEntryBlock().front());
1508     AllocaInst *Handle = new AllocaInst(
1509         IRBuilder.AsyncInfo, F->getAddressSpace(), "handle", FirstInst);
1510 
1511     // Add "issue" runtime call declaration:
1512     // declare %struct.tgt_async_info @__tgt_target_data_begin_issue(i64, i32,
1513     //   i8**, i8**, i64*, i64*)
1514     FunctionCallee IssueDecl = IRBuilder.getOrCreateRuntimeFunction(
1515         M, OMPRTL___tgt_target_data_begin_mapper_issue);
1516 
1517     // Change RuntimeCall call site for its asynchronous version.
1518     SmallVector<Value *, 16> Args;
1519     for (auto &Arg : RuntimeCall.args())
1520       Args.push_back(Arg.get());
1521     Args.push_back(Handle);
1522 
1523     CallInst *IssueCallsite =
1524         CallInst::Create(IssueDecl, Args, /*NameStr=*/"", &RuntimeCall);
1525     RuntimeCall.eraseFromParent();
1526 
1527     // Add "wait" runtime call declaration:
1528     // declare void @__tgt_target_data_begin_wait(i64, %struct.__tgt_async_info)
1529     FunctionCallee WaitDecl = IRBuilder.getOrCreateRuntimeFunction(
1530         M, OMPRTL___tgt_target_data_begin_mapper_wait);
1531 
1532     Value *WaitParams[2] = {
1533         IssueCallsite->getArgOperand(
1534             OffloadArray::DeviceIDArgNum), // device_id.
1535         Handle                             // handle to wait on.
1536     };
1537     CallInst::Create(WaitDecl, WaitParams, /*NameStr=*/"", &WaitMovementPoint);
1538 
1539     return true;
1540   }
1541 
1542   static Value *combinedIdentStruct(Value *CurrentIdent, Value *NextIdent,
1543                                     bool GlobalOnly, bool &SingleChoice) {
1544     if (CurrentIdent == NextIdent)
1545       return CurrentIdent;
1546 
1547     // TODO: Figure out how to actually combine multiple debug locations. For
1548     //       now we just keep an existing one if there is a single choice.
1549     if (!GlobalOnly || isa<GlobalValue>(NextIdent)) {
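      // NextIdent is acceptable; it remains the single choice only if we had
      // not picked a (different) ident before.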
1550       SingleChoice = !CurrentIdent;
1551       return NextIdent;
1552     }
1553     return nullptr;
1554   }
1555 
1556   /// Return a `struct ident_t*` value that represents the ones used in the
1557   /// calls of \p RFI inside of \p F. If \p GlobalOnly is true, we will not
1558   /// return a local `struct ident_t*`. For now, if we cannot find a suitable
1559   /// return value we create one from scratch. We also do not yet combine
1560   /// information, e.g., the source locations, see combinedIdentStruct.
1561   Value *
1562   getCombinedIdentFromCallUsesIn(OMPInformationCache::RuntimeFunctionInfo &RFI,
1563                                  Function &F, bool GlobalOnly) {
1564     bool SingleChoice = true;
1565     Value *Ident = nullptr;
1566     auto CombineIdentStruct = [&](Use &U, Function &Caller) {
1567       CallInst *CI = getCallIfRegularCall(U, &RFI);
1568       if (!CI || &F != &Caller)
1569         return false;
1570       Ident = combinedIdentStruct(Ident, CI->getArgOperand(0),
1571                                   /* GlobalOnly */ true, SingleChoice);
1572       return false;
1573     };
1574     RFI.foreachUse(SCC, CombineIdentStruct);
1575 
1576     if (!Ident || !SingleChoice) {
1577       // The IRBuilder uses the insertion block to get to the module; this is
1578       // unfortunate but we work around it for now.
1579       if (!OMPInfoCache.OMPBuilder.getInsertionPoint().getBlock())
1580         OMPInfoCache.OMPBuilder.updateToLocation(OpenMPIRBuilder::InsertPointTy(
1581             &F.getEntryBlock(), F.getEntryBlock().begin()));
1582       // Create a fallback location if none was found.
1583       // TODO: Use the debug locations of the calls instead.
1584       Constant *Loc = OMPInfoCache.OMPBuilder.getOrCreateDefaultSrcLocStr();
1585       Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(Loc);
1586     }
1587     return Ident;
1588   }
1589 
1590   /// Try to eliminate calls of \p RFI in \p F by reusing an existing one or
1591   /// \p ReplVal if given.
1592   bool deduplicateRuntimeCalls(Function &F,
1593                                OMPInformationCache::RuntimeFunctionInfo &RFI,
1594                                Value *ReplVal = nullptr) {
1595     auto *UV = RFI.getUseVector(F);
1596     if (!UV || UV->size() + (ReplVal != nullptr) < 2)
1597       return false;
1598 
1599     LLVM_DEBUG(
1600         dbgs() << TAG << "Deduplicate " << UV->size() << " uses of " << RFI.Name
1601                << (ReplVal ? " with an existing value" : "") << "\n");
1602 
1603     assert((!ReplVal || (isa<Argument>(ReplVal) &&
1604                          cast<Argument>(ReplVal)->getParent() == &F)) &&
1605            "Unexpected replacement value!");
1606 
1607     // TODO: Use dominance to find a good position instead.
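    // A call can only be moved to the new position (function entry or right
    // after the kernel init call) if its operands are still valid there: the
    // first operand, if any, has to be the `ident_t *` (replaced by a global
    // ident below) and no other operand may be an instruction.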
1608     auto CanBeMoved = [this](CallBase &CB) {
1609       unsigned NumArgs = CB.getNumArgOperands();
1610       if (NumArgs == 0)
1611         return true;
1612       if (CB.getArgOperand(0)->getType() != OMPInfoCache.OMPBuilder.IdentPtr)
1613         return false;
1614       for (unsigned u = 1; u < NumArgs; ++u)
1615         if (isa<Instruction>(CB.getArgOperand(u)))
1616           return false;
1617       return true;
1618     };
1619 
1620     if (!ReplVal) {
1621       for (Use *U : *UV)
1622         if (CallInst *CI = getCallIfRegularCall(*U, &RFI)) {
1623           if (!CanBeMoved(*CI))
1624             continue;
1625 
1626           // If the function is a kernel, dedup will move the runtime call
1627           // right after the kernel init callsite. Otherwise, it will move it
1628           // to the beginning of the caller function.
1629           if (isKernel(F)) {
1630             auto &KernelInitRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
1631             auto *KernelInitUV = KernelInitRFI.getUseVector(F);
1632 
1633             if (KernelInitUV->empty())
1634               continue;
1635 
1636             assert(KernelInitUV->size() == 1 &&
1637                    "Expected a single __kmpc_target_init in kernel\n");
1638 
1639             CallInst *KernelInitCI =
1640                 getCallIfRegularCall(*KernelInitUV->front(), &KernelInitRFI);
1641             assert(KernelInitCI &&
1642                    "Expected a call to __kmpc_target_init in kernel\n");
1643 
1644             CI->moveAfter(KernelInitCI);
1645           } else
1646             CI->moveBefore(&*F.getEntryBlock().getFirstInsertionPt());
1647           ReplVal = CI;
1648           break;
1649         }
1650       if (!ReplVal)
1651         return false;
1652     }
1653 
1654     // If we use a call as a replacement value we need to make sure the ident is
1655     // valid at the new location. For now we just pick a global one, either
1656     // existing and used by one of the calls, or created from scratch.
1657     if (CallBase *CI = dyn_cast<CallBase>(ReplVal)) {
1658       if (!CI->arg_empty() &&
1659           CI->getArgOperand(0)->getType() == OMPInfoCache.OMPBuilder.IdentPtr) {
1660         Value *Ident = getCombinedIdentFromCallUsesIn(RFI, F,
1661                                                       /* GlobalOnly */ true);
1662         CI->setArgOperand(0, Ident);
1663       }
1664     }
1665 
1666     bool Changed = false;
1667     auto ReplaceAndDeleteCB = [&](Use &U, Function &Caller) {
1668       CallInst *CI = getCallIfRegularCall(U, &RFI);
1669       if (!CI || CI == ReplVal || &F != &Caller)
1670         return false;
1671       assert(CI->getCaller() == &F && "Unexpected call!");
1672 
1673       auto Remark = [&](OptimizationRemark OR) {
1674         return OR << "OpenMP runtime call "
1675                   << ore::NV("OpenMPOptRuntime", RFI.Name) << " deduplicated.";
1676       };
1677       if (CI->getDebugLoc())
1678         emitRemark<OptimizationRemark>(CI, "OMP170", Remark);
1679       else
1680         emitRemark<OptimizationRemark>(&F, "OMP170", Remark);
1681 
1682       CGUpdater.removeCallSite(*CI);
1683       CI->replaceAllUsesWith(ReplVal);
1684       CI->eraseFromParent();
1685       ++NumOpenMPRuntimeCallsDeduplicated;
1686       Changed = true;
1687       return true;
1688     };
1689     RFI.foreachUse(SCC, ReplaceAndDeleteCB);
1690 
1691     return Changed;
1692   }
1693 
1694   /// Collect arguments that represent the global thread id in \p GTIdArgs.
1695   void collectGlobalThreadIdArguments(SmallSetVector<Value *, 16> &GTIdArgs) {
1696     // TODO: Below we basically perform a fixpoint iteration with a pessimistic
1697     //       initialization. We could define an AbstractAttribute instead and
1698     //       run the Attributor here once it can be run as an SCC pass.
1699 
1700     // Helper to check the argument \p ArgNo at all call sites of \p F for
1701     // a GTId.
1702     auto CallArgOpIsGTId = [&](Function &F, unsigned ArgNo, CallInst &RefCI) {
1703       if (!F.hasLocalLinkage())
1704         return false;
1705       for (Use &U : F.uses()) {
1706         if (CallInst *CI = getCallIfRegularCall(U)) {
1707           Value *ArgOp = CI->getArgOperand(ArgNo);
1708           if (CI == &RefCI || GTIdArgs.count(ArgOp) ||
1709               getCallIfRegularCall(
1710                   *ArgOp, &OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num]))
1711             continue;
1712         }
1713         return false;
1714       }
1715       return true;
1716     };
1717 
1718     // Helper to identify uses of a GTId as GTId arguments.
1719     auto AddUserArgs = [&](Value &GTId) {
1720       for (Use &U : GTId.uses())
1721         if (CallInst *CI = dyn_cast<CallInst>(U.getUser()))
1722           if (CI->isArgOperand(&U))
1723             if (Function *Callee = CI->getCalledFunction())
1724               if (CallArgOpIsGTId(*Callee, U.getOperandNo(), *CI))
1725                 GTIdArgs.insert(Callee->getArg(U.getOperandNo()));
1726     };
1727 
1728     // The argument users of __kmpc_global_thread_num calls are GTIds.
1729     OMPInformationCache::RuntimeFunctionInfo &GlobThreadNumRFI =
1730         OMPInfoCache.RFIs[OMPRTL___kmpc_global_thread_num];
1731 
1732     GlobThreadNumRFI.foreachUse(SCC, [&](Use &U, Function &F) {
1733       if (CallInst *CI = getCallIfRegularCall(U, &GlobThreadNumRFI))
1734         AddUserArgs(*CI);
1735       return false;
1736     });
1737 
1738     // Transitively search for more arguments by looking at the users of the
1739     // ones we know already. During the search the GTIdArgs vector is extended
1740     // so we cannot cache the size nor can we use a range based for.
1741     for (unsigned u = 0; u < GTIdArgs.size(); ++u)
1742       AddUserArgs(*GTIdArgs[u]);
1743   }
1744 
1745   /// Kernel (=GPU) optimizations and utility functions
1746   ///
1747   ///{{
1748 
1749   /// Check if \p F is a kernel, hence entry point for target offloading.
1750   bool isKernel(Function &F) { return OMPInfoCache.Kernels.count(&F); }
1751 
1752   /// Cache to remember the unique kernel for a function.
1753   DenseMap<Function *, Optional<Kernel>> UniqueKernelMap;
1754 
1755   /// Find the unique kernel that will execute \p F, if any.
1756   Kernel getUniqueKernelFor(Function &F);
1757 
1758   /// Find the unique kernel that will execute \p I, if any.
1759   Kernel getUniqueKernelFor(Instruction &I) {
1760     return getUniqueKernelFor(*I.getFunction());
1761   }
1762 
1763   /// Rewrite the device (=GPU) code state machine created in non-SPMD mode in
1764   /// the cases where we can avoid taking the address of a function.
1765   bool rewriteDeviceCodeStateMachine();
1766 
1767   ///
1768   ///}}
1769 
1770   /// Emit a remark generically
1771   ///
1772   /// This template function can be used to generically emit a remark. The
1773   /// RemarkKind should be one of the following:
1774   ///   - OptimizationRemark to indicate a successful optimization attempt
1775   ///   - OptimizationRemarkMissed to report a failed optimization attempt
1776   ///   - OptimizationRemarkAnalysis to provide additional information about an
1777   ///     optimization attempt
1778   ///
1779   /// The remark is built using a callback function provided by the caller that
1780   /// takes a RemarkKind as input and returns a RemarkKind.
1781   template <typename RemarkKind, typename RemarkCallBack>
1782   void emitRemark(Instruction *I, StringRef RemarkName,
1783                   RemarkCallBack &&RemarkCB) const {
1784     Function *F = I->getParent()->getParent();
1785     auto &ORE = OREGetter(F);
1786 
1787     if (RemarkName.startswith("OMP"))
1788       ORE.emit([&]() {
1789         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I))
1790                << " [" << RemarkName << "]";
1791       });
1792     else
1793       ORE.emit(
1794           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, I)); });
1795   }
1796 
1797   /// Emit a remark on a function.
1798   template <typename RemarkKind, typename RemarkCallBack>
1799   void emitRemark(Function *F, StringRef RemarkName,
1800                   RemarkCallBack &&RemarkCB) const {
1801     auto &ORE = OREGetter(F);
1802 
1803     if (RemarkName.startswith("OMP"))
1804       ORE.emit([&]() {
1805         return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F))
1806                << " [" << RemarkName << "]";
1807       });
1808     else
1809       ORE.emit(
1810           [&]() { return RemarkCB(RemarkKind(DEBUG_TYPE, RemarkName, F)); });
1811   }
1812 
1813   /// RAII struct to temporarily change an RTL function's linkage to external.
1814   /// This prevents it from being mistakenly removed by other optimizations.
1815   struct ExternalizationRAII {
1816     ExternalizationRAII(OMPInformationCache &OMPInfoCache,
1817                         RuntimeFunction RFKind)
1818         : Declaration(OMPInfoCache.RFIs[RFKind].Declaration) {
1819       if (!Declaration)
1820         return;
1821 
1822       LinkageType = Declaration->getLinkage();
1823       Declaration->setLinkage(GlobalValue::ExternalLinkage);
1824     }
1825 
1826     ~ExternalizationRAII() {
1827       if (!Declaration)
1828         return;
1829 
1830       Declaration->setLinkage(LinkageType);
1831     }
1832 
1833     Function *Declaration;
1834     GlobalValue::LinkageTypes LinkageType;
1835   };
1836 
1837   /// The underlying module.
1838   Module &M;
1839 
1840   /// The SCC we are operating on.
1841   SmallVectorImpl<Function *> &SCC;
1842 
1843   /// Callback to update the call graph; the first argument is a removed call,
1844   /// the second an optional replacement call.
1845   CallGraphUpdater &CGUpdater;
1846 
1847   /// Callback to get an OptimizationRemarkEmitter from a Function *
1848   OptimizationRemarkGetter OREGetter;
1849 
1850   /// OpenMP-specific information cache. Also used for Attributor runs.
1851   OMPInformationCache &OMPInfoCache;
1852 
1853   /// Attributor instance.
1854   Attributor &A;
1855 
1856   /// Helper function to run Attributor on SCC.
1857   bool runAttributor(bool IsModulePass) {
1858     if (SCC.empty())
1859       return false;
1860 
1861     // Temporarily make these functions have external linkage so the Attributor
1862     // doesn't remove them when we try to look them up later.
1863     ExternalizationRAII Parallel(OMPInfoCache, OMPRTL___kmpc_kernel_parallel);
1864     ExternalizationRAII EndParallel(OMPInfoCache,
1865                                     OMPRTL___kmpc_kernel_end_parallel);
1866     ExternalizationRAII BarrierSPMD(OMPInfoCache,
1867                                     OMPRTL___kmpc_barrier_simple_spmd);
1868     ExternalizationRAII ThreadId(OMPInfoCache,
1869                                  OMPRTL___kmpc_get_hardware_thread_id_in_block);
1870 
1871     registerAAs(IsModulePass);
1872 
1873     ChangeStatus Changed = A.run();
1874 
1875     LLVM_DEBUG(dbgs() << "[Attributor] Done with " << SCC.size()
1876                       << " functions, result: " << Changed << ".\n");
1877 
1878     return Changed == ChangeStatus::CHANGED;
1879   }
1880 
1881   void registerFoldRuntimeCall(RuntimeFunction RF);
1882 
1883   /// Populate the Attributor with abstract attribute opportunities in the
1884   /// function.
1885   void registerAAs(bool IsModulePass);
1886 };
1887 
1888 Kernel OpenMPOpt::getUniqueKernelFor(Function &F) {
1889   if (!OMPInfoCache.ModuleSlice.count(&F))
1890     return nullptr;
1891 
1892   // Use a scope to keep the lifetime of the CachedKernel short.
1893   {
1894     Optional<Kernel> &CachedKernel = UniqueKernelMap[&F];
1895     if (CachedKernel)
1896       return *CachedKernel;
1897 
1898     // TODO: We should use an AA to create an (optimistic and callback
1899     //       call-aware) call graph. For now we stick to simple patterns that
1900     //       are less powerful, basically the worst fixpoint.
1901     if (isKernel(F)) {
1902       CachedKernel = Kernel(&F);
1903       return *CachedKernel;
1904     }
1905 
1906     CachedKernel = nullptr;
1907     if (!F.hasLocalLinkage()) {
1908 
1909       // See https://openmp.llvm.org/remarks/OptimizationRemarks.html
1910       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
1911         return ORA << "Potentially unknown OpenMP target region caller.";
1912       };
1913       emitRemark<OptimizationRemarkAnalysis>(&F, "OMP100", Remark);
1914 
1915       return nullptr;
1916     }
1917   }
1918 
1919   auto GetUniqueKernelForUse = [&](const Use &U) -> Kernel {
1920     if (auto *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
1921       // Allow use in equality comparisons.
1922       if (Cmp->isEquality())
1923         return getUniqueKernelFor(*Cmp);
1924       return nullptr;
1925     }
1926     if (auto *CB = dyn_cast<CallBase>(U.getUser())) {
1927       // Allow direct calls.
1928       if (CB->isCallee(&U))
1929         return getUniqueKernelFor(*CB);
1930 
1931       OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1932           OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1933       // Allow the use in __kmpc_parallel_51 calls.
1934       if (OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI))
1935         return getUniqueKernelFor(*CB);
1936       return nullptr;
1937     }
1938     // Disallow every other use.
1939     return nullptr;
1940   };
1941 
1942   // TODO: In the future we want to track more than just a unique kernel.
1943   SmallPtrSet<Kernel, 2> PotentialKernels;
1944   OMPInformationCache::foreachUse(F, [&](const Use &U) {
1945     PotentialKernels.insert(GetUniqueKernelForUse(U));
1946   });
1947 
1948   Kernel K = nullptr;
1949   if (PotentialKernels.size() == 1)
1950     K = *PotentialKernels.begin();
1951 
1952   // Cache the result.
1953   UniqueKernelMap[&F] = K;
1954 
1955   return K;
1956 }
1957 
1958 bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
1959   OMPInformationCache::RuntimeFunctionInfo &KernelParallelRFI =
1960       OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
1961 
1962   bool Changed = false;
1963   if (!KernelParallelRFI)
1964     return Changed;
1965 
1966   // If we have disabled state machine changes, exit
1967   if (DisableOpenMPOptStateMachineRewrite)
1968     return Changed;
1969 
1970   for (Function *F : SCC) {
1971 
1972     // Check if the function is a use in a __kmpc_parallel_51 call at
1973     // all.
1974     bool UnknownUse = false;
1975     bool KernelParallelUse = false;
1976     unsigned NumDirectCalls = 0;
1977 
1978     SmallVector<Use *, 2> ToBeReplacedStateMachineUses;
1979     OMPInformationCache::foreachUse(*F, [&](Use &U) {
1980       if (auto *CB = dyn_cast<CallBase>(U.getUser()))
1981         if (CB->isCallee(&U)) {
1982           ++NumDirectCalls;
1983           return;
1984         }
1985 
1986       if (isa<ICmpInst>(U.getUser())) {
1987         ToBeReplacedStateMachineUses.push_back(&U);
1988         return;
1989       }
1990 
1991       // Find wrapper functions that represent parallel kernels.
1992       CallInst *CI =
1993           OpenMPOpt::getCallIfRegularCall(*U.getUser(), &KernelParallelRFI);
1994       const unsigned int WrapperFunctionArgNo = 6;
1995       if (!KernelParallelUse && CI &&
1996           CI->getArgOperandNo(&U) == WrapperFunctionArgNo) {
1997         KernelParallelUse = true;
1998         ToBeReplacedStateMachineUses.push_back(&U);
1999         return;
2000       }
2001       UnknownUse = true;
2002     });
2003 
2004     // Do not emit a remark if we haven't seen a __kmpc_parallel_51
2005     // use.
2006     if (!KernelParallelUse)
2007       continue;
2008 
2009     // If this ever hits, we should investigate.
2010     // TODO: Checking the number of uses is not a necessary restriction and
2011     // should be lifted.
2012     if (UnknownUse || NumDirectCalls != 1 ||
2013         ToBeReplacedStateMachineUses.size() > 2) {
2014       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2015         return ORA << "Parallel region is used in "
2016                    << (UnknownUse ? "unknown" : "unexpected")
2017                    << " ways. Will not attempt to rewrite the state machine.";
2018       };
2019       emitRemark<OptimizationRemarkAnalysis>(F, "OMP101", Remark);
2020       continue;
2021     }
2022 
2023     // Even if we have __kmpc_parallel_51 calls, we (for now) give
2024     // up if the function is not called from a unique kernel.
2025     Kernel K = getUniqueKernelFor(*F);
2026     if (!K) {
2027       auto Remark = [&](OptimizationRemarkAnalysis ORA) {
2028         return ORA << "Parallel region is not called from a unique kernel. "
2029                       "Will not attempt to rewrite the state machine.";
2030       };
2031       emitRemark<OptimizationRemarkAnalysis>(F, "OMP102", Remark);
2032       continue;
2033     }
2034 
2035     // We now know F is a parallel body function called only from the kernel K.
2036     // We also identified the state machine uses in which we replace the
2037     // function pointer by a new global symbol for identification purposes. This
2038     // ensures only direct calls to the function are left.
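    // The ID acts as a stand-in for the parallel region: the specialized
    // state machine compares the communicated work-function pointer against
    // this ID rather than against F itself, so F's address does not need to
    // be taken.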
2039 
2040     Module &M = *F->getParent();
2041     Type *Int8Ty = Type::getInt8Ty(M.getContext());
2042 
2043     auto *ID = new GlobalVariable(
2044         M, Int8Ty, /* isConstant */ true, GlobalValue::PrivateLinkage,
2045         UndefValue::get(Int8Ty), F->getName() + ".ID");
2046 
2047     for (Use *U : ToBeReplacedStateMachineUses)
2048       U->set(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
2049           ID, U->get()->getType()));
2050 
2051     ++NumOpenMPParallelRegionsReplacedInGPUStateMachine;
2052 
2053     Changed = true;
2054   }
2055 
2056   return Changed;
2057 }
2058 
2059 /// Abstract Attribute for tracking ICV values.
2060 struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
2061   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2062   AAICVTracker(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2063 
2064   void initialize(Attributor &A) override {
2065     Function *F = getAnchorScope();
2066     if (!F || !A.isFunctionIPOAmendable(*F))
2067       indicatePessimisticFixpoint();
2068   }
2069 
2070   /// Returns true if value is assumed to be tracked.
2071   bool isAssumedTracked() const { return getAssumed(); }
2072 
2073   /// Returns true if value is known to be tracked.
2074   bool isKnownTracked() const { return getAssumed(); }
2075 
2076   /// Create an abstract attribute view for the position \p IRP.
2077   static AAICVTracker &createForPosition(const IRPosition &IRP, Attributor &A);
2078 
2079   /// Return the value with which \p I can be replaced for specific \p ICV.
2080   virtual Optional<Value *> getReplacementValue(InternalControlVar ICV,
2081                                                 const Instruction *I,
2082                                                 Attributor &A) const {
2083     return None;
2084   }
2085 
2086   /// Return an assumed unique ICV value if a single candidate is found. If
2087   /// there cannot be one, return a nullptr. If it is not clear yet, return the
2088   /// Optional::NoneType.
2089   virtual Optional<Value *>
2090   getUniqueReplacementValue(InternalControlVar ICV) const = 0;
2091 
2092   // Currently only nthreads is being tracked.
2093   // This array will only grow over time.
2094   InternalControlVar TrackableICVs[1] = {ICV_nthreads};
2095 
2096   /// See AbstractAttribute::getName()
2097   const std::string getName() const override { return "AAICVTracker"; }
2098 
2099   /// See AbstractAttribute::getIdAddr()
2100   const char *getIdAddr() const override { return &ID; }
2101 
2102   /// This function should return true if the type of the \p AA is AAICVTracker
2103   static bool classof(const AbstractAttribute *AA) {
2104     return (AA->getIdAddr() == &ID);
2105   }
2106 
2107   static const char ID;
2108 };
2109 
2110 struct AAICVTrackerFunction : public AAICVTracker {
2111   AAICVTrackerFunction(const IRPosition &IRP, Attributor &A)
2112       : AAICVTracker(IRP, A) {}
2113 
2114   // FIXME: come up with better string.
2115   const std::string getAsStr() const override { return "ICVTrackerFunction"; }
2116 
2117   // FIXME: come up with some stats.
2118   void trackStatistics() const override {}
2119 
2120   /// We don't manifest anything for this AA.
2121   ChangeStatus manifest(Attributor &A) override {
2122     return ChangeStatus::UNCHANGED;
2123   }
2124 
2125   // Map of ICV to their values at specific program point.
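  // A mapped value of nullptr means the ICV value is unknown (potentially
  // changed) at that point; no entry means the instruction does not affect
  // the ICV.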
2126   EnumeratedArray<DenseMap<Instruction *, Value *>, InternalControlVar,
2127                   InternalControlVar::ICV___last>
2128       ICVReplacementValuesMap;
2129 
2130   ChangeStatus updateImpl(Attributor &A) override {
2131     ChangeStatus HasChanged = ChangeStatus::UNCHANGED;
2132 
2133     Function *F = getAnchorScope();
2134 
2135     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2136 
2137     for (InternalControlVar ICV : TrackableICVs) {
2138       auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2139 
2140       auto &ValuesMap = ICVReplacementValuesMap[ICV];
2141       auto TrackValues = [&](Use &U, Function &) {
2142         CallInst *CI = OpenMPOpt::getCallIfRegularCall(U);
2143         if (!CI)
2144           return false;
2145 
2146         // FIXME: Handle setters with more than one argument.
2147         // Track new value.
2148         if (ValuesMap.insert(std::make_pair(CI, CI->getArgOperand(0))).second)
2149           HasChanged = ChangeStatus::CHANGED;
2150 
2151         return false;
2152       };
2153 
2154       auto CallCheck = [&](Instruction &I) {
2155         Optional<Value *> ReplVal = getValueForCall(A, &I, ICV);
2156         if (ReplVal.hasValue() &&
2157             ValuesMap.insert(std::make_pair(&I, *ReplVal)).second)
2158           HasChanged = ChangeStatus::CHANGED;
2159 
2160         return true;
2161       };
2162 
2163       // Track all changes of an ICV.
2164       SetterRFI.foreachUse(TrackValues, F);
2165 
2166       bool UsedAssumedInformation = false;
2167       A.checkForAllInstructions(CallCheck, *this, {Instruction::Call},
2168                                 UsedAssumedInformation,
2169                                 /* CheckBBLivenessOnly */ true);
2170 
2171       // TODO: Figure out a way to avoid adding an entry in
2172       // ICVReplacementValuesMap.
2173       Instruction *Entry = &F->getEntryBlock().front();
2174       if (HasChanged == ChangeStatus::CHANGED && !ValuesMap.count(Entry))
2175         ValuesMap.insert(std::make_pair(Entry, nullptr));
2176     }
2177 
2178     return HasChanged;
2179   }
2180 
2181   /// Helper to check if \p I is a call and get the value for it if it is
2182   /// unique.
2183   Optional<Value *> getValueForCall(Attributor &A, const Instruction *I,
2184                                     InternalControlVar &ICV) const {
2185 
2186     const auto *CB = dyn_cast<CallBase>(I);
2187     if (!CB || CB->hasFnAttr("no_openmp") ||
2188         CB->hasFnAttr("no_openmp_routines"))
2189       return None;
2190 
2191     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2192     auto &GetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Getter];
2193     auto &SetterRFI = OMPInfoCache.RFIs[OMPInfoCache.ICVs[ICV].Setter];
2194     Function *CalledFunction = CB->getCalledFunction();
2195 
2196     // Indirect call, assume ICV changes.
2197     if (CalledFunction == nullptr)
2198       return nullptr;
2199     if (CalledFunction == GetterRFI.Declaration)
2200       return None;
2201     if (CalledFunction == SetterRFI.Declaration) {
2202       if (ICVReplacementValuesMap[ICV].count(I))
2203         return ICVReplacementValuesMap[ICV].lookup(I);
2204 
2205       return nullptr;
2206     }
2207 
2208     // Since we don't know, assume it changes the ICV.
2209     if (CalledFunction->isDeclaration())
2210       return nullptr;
2211 
2212     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2213         *this, IRPosition::callsite_returned(*CB), DepClassTy::REQUIRED);
2214 
2215     if (ICVTrackingAA.isAssumedTracked())
2216       return ICVTrackingAA.getUniqueReplacementValue(ICV);
2217 
2218     // If we don't know, assume it changes.
2219     return nullptr;
2220   }
2221 
2222   // We don't check for a unique value for a function, so return None.
2223   Optional<Value *>
2224   getUniqueReplacementValue(InternalControlVar ICV) const override {
2225     return None;
2226   }
2227 
2228   /// Return the value with which \p I can be replaced for specific \p ICV.
2229   Optional<Value *> getReplacementValue(InternalControlVar ICV,
2230                                         const Instruction *I,
2231                                         Attributor &A) const override {
2232     const auto &ValuesMap = ICVReplacementValuesMap[ICV];
2233     if (ValuesMap.count(I))
2234       return ValuesMap.lookup(I);
2235 
2236     SmallVector<const Instruction *, 16> Worklist;
2237     SmallPtrSet<const Instruction *, 16> Visited;
2238     Worklist.push_back(I);
2239 
2240     Optional<Value *> ReplVal;
2241 
2242     while (!Worklist.empty()) {
2243       const Instruction *CurrInst = Worklist.pop_back_val();
2244       if (!Visited.insert(CurrInst).second)
2245         continue;
2246 
2247       const BasicBlock *CurrBB = CurrInst->getParent();
2248 
2249       // Go up and look for all potential setters/calls that might change the
2250       // ICV.
2251       while ((CurrInst = CurrInst->getPrevNode())) {
2252         if (ValuesMap.count(CurrInst)) {
2253           Optional<Value *> NewReplVal = ValuesMap.lookup(CurrInst);
2254           // Unknown value, track new.
2255           if (!ReplVal.hasValue()) {
2256             ReplVal = NewReplVal;
2257             break;
2258           }
2259 
2260           // If we found a new value, we can't know the ICV value anymore.
2261           if (NewReplVal.hasValue())
2262             if (ReplVal != NewReplVal)
2263               return nullptr;
2264 
2265           break;
2266         }
2267 
2268         Optional<Value *> NewReplVal = getValueForCall(A, CurrInst, ICV);
2269         if (!NewReplVal.hasValue())
2270           continue;
2271 
2272         // Unknown value, track new.
2273         if (!ReplVal.hasValue()) {
2274           ReplVal = NewReplVal;
2275           break;
2276         }
2277 
2278         // We found a new value; we can't know the ICV value anymore.
2280         if (ReplVal != NewReplVal)
2281           return nullptr;
2282       }
2283 
2284       // If we are in the same BB and we have a value, we are done.
2285       if (CurrBB == I->getParent() && ReplVal.hasValue())
2286         return ReplVal;
2287 
2288       // Go through all predecessors and add terminators for analysis.
2289       for (const BasicBlock *Pred : predecessors(CurrBB))
2290         if (const Instruction *Terminator = Pred->getTerminator())
2291           Worklist.push_back(Terminator);
2292     }
2293 
2294     return ReplVal;
2295   }
2296 };
2297 
2298 struct AAICVTrackerFunctionReturned : AAICVTracker {
2299   AAICVTrackerFunctionReturned(const IRPosition &IRP, Attributor &A)
2300       : AAICVTracker(IRP, A) {}
2301 
2302   // FIXME: come up with better string.
2303   const std::string getAsStr() const override {
2304     return "ICVTrackerFunctionReturned";
2305   }
2306 
2307   // FIXME: come up with some stats.
2308   void trackStatistics() const override {}
2309 
2310   /// We don't manifest anything for this AA.
2311   ChangeStatus manifest(Attributor &A) override {
2312     return ChangeStatus::UNCHANGED;
2313   }
2314 
2315   // Map of ICV to their values at specific program point.
2316   EnumeratedArray<Optional<Value *>, InternalControlVar,
2317                   InternalControlVar::ICV___last>
2318       ICVReplacementValuesMap;
2319 
2320   /// Return the value with which \p I can be replaced for specific \p ICV.
2321   Optional<Value *>
2322   getUniqueReplacementValue(InternalControlVar ICV) const override {
2323     return ICVReplacementValuesMap[ICV];
2324   }
2325 
2326   ChangeStatus updateImpl(Attributor &A) override {
2327     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2328     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2329         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2330 
2331     if (!ICVTrackingAA.isAssumedTracked())
2332       return indicatePessimisticFixpoint();
2333 
2334     for (InternalControlVar ICV : TrackableICVs) {
2335       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2336       Optional<Value *> UniqueICVValue;
2337 
2338       auto CheckReturnInst = [&](Instruction &I) {
2339         Optional<Value *> NewReplVal =
2340             ICVTrackingAA.getReplacementValue(ICV, &I, A);
2341 
2342         // If we found a second ICV value there is no unique returned value.
2343         if (UniqueICVValue.hasValue() && UniqueICVValue != NewReplVal)
2344           return false;
2345 
2346         UniqueICVValue = NewReplVal;
2347 
2348         return true;
2349       };
2350 
2351       bool UsedAssumedInformation = false;
2352       if (!A.checkForAllInstructions(CheckReturnInst, *this, {Instruction::Ret},
2353                                      UsedAssumedInformation,
2354                                      /* CheckBBLivenessOnly */ true))
2355         UniqueICVValue = nullptr;
2356 
2357       if (UniqueICVValue == ReplVal)
2358         continue;
2359 
2360       ReplVal = UniqueICVValue;
2361       Changed = ChangeStatus::CHANGED;
2362     }
2363 
2364     return Changed;
2365   }
2366 };
2367 
2368 struct AAICVTrackerCallSite : AAICVTracker {
2369   AAICVTrackerCallSite(const IRPosition &IRP, Attributor &A)
2370       : AAICVTracker(IRP, A) {}
2371 
2372   void initialize(Attributor &A) override {
2373     Function *F = getAnchorScope();
2374     if (!F || !A.isFunctionIPOAmendable(*F))
2375       indicatePessimisticFixpoint();
2376 
2377     // We only initialize this AA for getters, so we need to know which ICV it
2378     // gets.
2379     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2380     for (InternalControlVar ICV : TrackableICVs) {
2381       auto ICVInfo = OMPInfoCache.ICVs[ICV];
2382       auto &Getter = OMPInfoCache.RFIs[ICVInfo.Getter];
2383       if (Getter.Declaration == getAssociatedFunction()) {
2384         AssociatedICV = ICVInfo.Kind;
2385         return;
2386       }
2387     }
2388 
2389     /// Unknown ICV.
2390     indicatePessimisticFixpoint();
2391   }
2392 
2393   ChangeStatus manifest(Attributor &A) override {
2394     if (!ReplVal.hasValue() || !ReplVal.getValue())
2395       return ChangeStatus::UNCHANGED;
2396 
2397     A.changeValueAfterManifest(*getCtxI(), **ReplVal);
2398     A.deleteAfterManifest(*getCtxI());
2399 
2400     return ChangeStatus::CHANGED;
2401   }
2402 
2403   // FIXME: come up with better string.
2404   const std::string getAsStr() const override { return "ICVTrackerCallSite"; }
2405 
2406   // FIXME: come up with some stats.
2407   void trackStatistics() const override {}
2408 
2409   InternalControlVar AssociatedICV;
2410   Optional<Value *> ReplVal;
2411 
2412   ChangeStatus updateImpl(Attributor &A) override {
2413     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2414         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
2415 
2416     // We don't have any information, so we assume it changes the ICV.
2417     if (!ICVTrackingAA.isAssumedTracked())
2418       return indicatePessimisticFixpoint();
2419 
2420     Optional<Value *> NewReplVal =
2421         ICVTrackingAA.getReplacementValue(AssociatedICV, getCtxI(), A);
2422 
2423     if (ReplVal == NewReplVal)
2424       return ChangeStatus::UNCHANGED;
2425 
2426     ReplVal = NewReplVal;
2427     return ChangeStatus::CHANGED;
2428   }
2429 
2430   /// Return the value with which the associated value can be replaced for a
2431   /// specific \p ICV.
2432   Optional<Value *>
2433   getUniqueReplacementValue(InternalControlVar ICV) const override {
2434     return ReplVal;
2435   }
2436 };
2437 
2438 struct AAICVTrackerCallSiteReturned : AAICVTracker {
2439   AAICVTrackerCallSiteReturned(const IRPosition &IRP, Attributor &A)
2440       : AAICVTracker(IRP, A) {}
2441 
2442   // FIXME: come up with better string.
2443   const std::string getAsStr() const override {
2444     return "ICVTrackerCallSiteReturned";
2445   }
2446 
2447   // FIXME: come up with some stats.
2448   void trackStatistics() const override {}
2449 
2450   /// We don't manifest anything for this AA.
2451   ChangeStatus manifest(Attributor &A) override {
2452     return ChangeStatus::UNCHANGED;
2453   }
2454 
2455   // Map of ICV to their values at specific program point.
2456   EnumeratedArray<Optional<Value *>, InternalControlVar,
2457                   InternalControlVar::ICV___last>
2458       ICVReplacementValuesMap;
2459 
2460   /// Return the value with which associated value can be replaced for specific
2461   /// \p ICV.
2462   Optional<Value *>
2463   getUniqueReplacementValue(InternalControlVar ICV) const override {
2464     return ICVReplacementValuesMap[ICV];
2465   }
2466 
2467   ChangeStatus updateImpl(Attributor &A) override {
2468     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2469     const auto &ICVTrackingAA = A.getAAFor<AAICVTracker>(
2470         *this, IRPosition::returned(*getAssociatedFunction()),
2471         DepClassTy::REQUIRED);
2472 
2473     // We don't have any information, so we assume it changes the ICV.
2474     if (!ICVTrackingAA.isAssumedTracked())
2475       return indicatePessimisticFixpoint();
2476 
2477     for (InternalControlVar ICV : TrackableICVs) {
2478       Optional<Value *> &ReplVal = ICVReplacementValuesMap[ICV];
2479       Optional<Value *> NewReplVal =
2480           ICVTrackingAA.getUniqueReplacementValue(ICV);
2481 
2482       if (ReplVal == NewReplVal)
2483         continue;
2484 
2485       ReplVal = NewReplVal;
2486       Changed = ChangeStatus::CHANGED;
2487     }
2488     return Changed;
2489   }
2490 };
2491 
2492 struct AAExecutionDomainFunction : public AAExecutionDomain {
2493   AAExecutionDomainFunction(const IRPosition &IRP, Attributor &A)
2494       : AAExecutionDomain(IRP, A) {}
2495 
2496   const std::string getAsStr() const override {
2497     return "[AAExecutionDomain] " + std::to_string(SingleThreadedBBs.size()) +
2498            "/" + std::to_string(NumBBs) + " BBs thread 0 only.";
2499   }
2500 
2501   /// See AbstractAttribute::trackStatistics().
2502   void trackStatistics() const override {}
2503 
2504   void initialize(Attributor &A) override {
2505     Function *F = getAnchorScope();
2506     for (const auto &BB : *F)
2507       SingleThreadedBBs.insert(&BB);
2508     NumBBs = SingleThreadedBBs.size();
2509   }
2510 
2511   ChangeStatus manifest(Attributor &A) override {
2512     LLVM_DEBUG({
2513       for (const BasicBlock *BB : SingleThreadedBBs)
2514         dbgs() << TAG << " Basic block @" << getAnchorScope()->getName() << " "
2515                << BB->getName() << " is executed by a single thread.\n";
2516     });
2517     return ChangeStatus::UNCHANGED;
2518   }
2519 
2520   ChangeStatus updateImpl(Attributor &A) override;
2521 
2522   /// Check if an instruction is executed by a single thread.
2523   bool isExecutedByInitialThreadOnly(const Instruction &I) const override {
2524     return isExecutedByInitialThreadOnly(*I.getParent());
2525   }
2526 
2527   bool isExecutedByInitialThreadOnly(const BasicBlock &BB) const override {
2528     return isValidState() && SingleThreadedBBs.contains(&BB);
2529   }
2530 
2531   /// Set of basic blocks that are executed by a single thread.
2532   DenseSet<const BasicBlock *> SingleThreadedBBs;
2533 
2534   /// Total number of basic blocks in this function.
2535   long unsigned NumBBs;
2536 };
2537 
2538 ChangeStatus AAExecutionDomainFunction::updateImpl(Attributor &A) {
2539   Function *F = getAnchorScope();
2540   ReversePostOrderTraversal<Function *> RPOT(F);
2541   auto NumSingleThreadedBBs = SingleThreadedBBs.size();
2542 
2543   bool AllCallSitesKnown;
2544   auto PredForCallSite = [&](AbstractCallSite ACS) {
2545     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
2546         *this, IRPosition::function(*ACS.getInstruction()->getFunction()),
2547         DepClassTy::REQUIRED);
2548     return ACS.isDirectCall() &&
2549            ExecutionDomainAA.isExecutedByInitialThreadOnly(
2550                *ACS.getInstruction());
2551   };
2552 
2553   if (!A.checkForAllCallSites(PredForCallSite, *this,
2554                               /* RequiresAllCallSites */ true,
2555                               AllCallSitesKnown))
2556     SingleThreadedBBs.erase(&F->getEntryBlock());
2557 
2558   auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2559   auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2560 
2561   // Check if the edge into the successor block contains a condition that only
2562   // lets the main thread execute it.
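  // Recognized guards are `__kmpc_target_init(...) == -1` for generic-mode
  // kernels and `<thread-id intrinsic> == 0` for NVPTX and AMDGPU; see the
  // matchers below.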
2563   auto IsInitialThreadOnly = [&](BranchInst *Edge, BasicBlock *SuccessorBB) {
2564     if (!Edge || !Edge->isConditional())
2565       return false;
2566     if (Edge->getSuccessor(0) != SuccessorBB)
2567       return false;
2568 
2569     auto *Cmp = dyn_cast<CmpInst>(Edge->getCondition());
2570     if (!Cmp || !Cmp->isTrueWhenEqual() || !Cmp->isEquality())
2571       return false;
2572 
2573     ConstantInt *C = dyn_cast<ConstantInt>(Cmp->getOperand(1));
2574     if (!C)
2575       return false;
2576 
2577     // Match: -1 == __kmpc_target_init (for non-SPMD kernels only!)
2578     if (C->isAllOnesValue()) {
2579       auto *CB = dyn_cast<CallBase>(Cmp->getOperand(0));
2580       CB = CB ? OpenMPOpt::getCallIfRegularCall(*CB, &RFI) : nullptr;
2581       if (!CB)
2582         return false;
2583       const int InitIsSPMDArgNo = 1;
2584       auto *IsSPMDModeCI =
2585           dyn_cast<ConstantInt>(CB->getOperand(InitIsSPMDArgNo));
2586       return IsSPMDModeCI && IsSPMDModeCI->isZero();
2587     }
2588 
2589     if (C->isZero()) {
2590       // Match: 0 == llvm.nvvm.read.ptx.sreg.tid.x()
2591       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2592         if (II->getIntrinsicID() == Intrinsic::nvvm_read_ptx_sreg_tid_x)
2593           return true;
2594 
2595       // Match: 0 == llvm.amdgcn.workitem.id.x()
2596       if (auto *II = dyn_cast<IntrinsicInst>(Cmp->getOperand(0)))
2597         if (II->getIntrinsicID() == Intrinsic::amdgcn_workitem_id_x)
2598           return true;
2599     }
2600 
2601     return false;
2602   };
2603 
2604   // Merge all the predecessor states into the current basic block. A basic
2605   // block is executed by a single thread if all of its predecessors are.
2606   auto MergePredecessorStates = [&](BasicBlock *BB) {
2607     if (pred_begin(BB) == pred_end(BB))
2608       return SingleThreadedBBs.contains(BB);
2609 
2610     bool IsInitialThread = true;
2611     for (auto PredBB = pred_begin(BB), PredEndBB = pred_end(BB);
2612          PredBB != PredEndBB; ++PredBB) {
2613       if (!IsInitialThreadOnly(dyn_cast<BranchInst>((*PredBB)->getTerminator()),
2614                                BB))
2615         IsInitialThread &= SingleThreadedBBs.contains(*PredBB);
2616     }
2617 
2618     return IsInitialThread;
2619   };
2620 
2621   for (auto *BB : RPOT) {
2622     if (!MergePredecessorStates(BB))
2623       SingleThreadedBBs.erase(BB);
2624   }
2625 
2626   return (NumSingleThreadedBBs == SingleThreadedBBs.size())
2627              ? ChangeStatus::UNCHANGED
2628              : ChangeStatus::CHANGED;
2629 }
2630 
2631 /// Try to replace memory allocation calls called by a single thread with a
2632 /// static buffer of shared memory.
2633 struct AAHeapToShared : public StateWrapper<BooleanState, AbstractAttribute> {
2634   using Base = StateWrapper<BooleanState, AbstractAttribute>;
2635   AAHeapToShared(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2636 
2637   /// Create an abstract attribute view for the position \p IRP.
2638   static AAHeapToShared &createForPosition(const IRPosition &IRP,
2639                                            Attributor &A);
2640 
2641   /// Returns true if HeapToShared conversion is assumed to be possible.
2642   virtual bool isAssumedHeapToShared(CallBase &CB) const = 0;
2643 
2644   /// Returns true if HeapToShared conversion is assumed and the CB is a
2645   /// callsite to a free operation to be removed.
2646   virtual bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const = 0;
2647 
2648   /// See AbstractAttribute::getName().
2649   const std::string getName() const override { return "AAHeapToShared"; }
2650 
2651   /// See AbstractAttribute::getIdAddr().
2652   const char *getIdAddr() const override { return &ID; }
2653 
2654   /// This function should return true if the type of the \p AA is
2655   /// AAHeapToShared.
2656   static bool classof(const AbstractAttribute *AA) {
2657     return (AA->getIdAddr() == &ID);
2658   }
2659 
2660   /// Unique ID (due to the unique address)
2661   static const char ID;
2662 };
2663 
2664 struct AAHeapToSharedFunction : public AAHeapToShared {
2665   AAHeapToSharedFunction(const IRPosition &IRP, Attributor &A)
2666       : AAHeapToShared(IRP, A) {}
2667 
2668   const std::string getAsStr() const override {
2669     return "[AAHeapToShared] " + std::to_string(MallocCalls.size()) +
2670            " malloc calls eligible.";
2671   }
2672 
2673   /// See AbstractAttribute::trackStatistics().
2674   void trackStatistics() const override {}
2675 
2676   /// This function finds free calls that will be removed by the
2677   /// HeapToShared transformation.
2678   void findPotentialRemovedFreeCalls(Attributor &A) {
2679     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2680     auto &FreeRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2681 
2682     PotentialRemovedFreeCalls.clear();
2683     // Update free call users of found malloc calls.
2684     for (CallBase *CB : MallocCalls) {
2685       SmallVector<CallBase *, 4> FreeCalls;
2686       for (auto *U : CB->users()) {
2687         CallBase *C = dyn_cast<CallBase>(U);
2688         if (C && C->getCalledFunction() == FreeRFI.Declaration)
2689           FreeCalls.push_back(C);
2690       }
2691 
2692       if (FreeCalls.size() != 1)
2693         continue;
2694 
2695       PotentialRemovedFreeCalls.insert(FreeCalls.front());
2696     }
2697   }
2698 
2699   void initialize(Attributor &A) override {
2700     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2701     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2702 
2703     for (User *U : RFI.Declaration->users())
2704       if (CallBase *CB = dyn_cast<CallBase>(U))
2705         MallocCalls.insert(CB);
2706 
2707     findPotentialRemovedFreeCalls(A);
2708   }
2709 
2710   bool isAssumedHeapToShared(CallBase &CB) const override {
2711     return isValidState() && MallocCalls.count(&CB);
2712   }
2713 
2714   bool isAssumedHeapToSharedRemovedFree(CallBase &CB) const override {
2715     return isValidState() && PotentialRemovedFreeCalls.count(&CB);
2716   }
2717 
2718   ChangeStatus manifest(Attributor &A) override {
2719     if (MallocCalls.empty())
2720       return ChangeStatus::UNCHANGED;
2721 
2722     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2723     auto &FreeCall = OMPInfoCache.RFIs[OMPRTL___kmpc_free_shared];
2724 
2725     Function *F = getAnchorScope();
2726     auto *HS = A.lookupAAFor<AAHeapToStack>(IRPosition::function(*F), this,
2727                                             DepClassTy::OPTIONAL);
2728 
2729     ChangeStatus Changed = ChangeStatus::UNCHANGED;
2730     for (CallBase *CB : MallocCalls) {
2731       // Skip replacing this if HeapToStack has already claimed it.
2732       if (HS && HS->isAssumedHeapToStack(*CB))
2733         continue;
2734 
2735       // Find the unique free call to remove it.
2736       SmallVector<CallBase *, 4> FreeCalls;
2737       for (auto *U : CB->users()) {
2738         CallBase *C = dyn_cast<CallBase>(U);
2739         if (C && C->getCalledFunction() == FreeCall.Declaration)
2740           FreeCalls.push_back(C);
2741       }
2742       if (FreeCalls.size() != 1)
2743         continue;
2744 
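      // The allocation size is expected to be a constant at this point; calls
      // with a non-constant size were removed from MallocCalls in updateImpl.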
2745       ConstantInt *AllocSize = dyn_cast<ConstantInt>(CB->getArgOperand(0));
2746 
2747       LLVM_DEBUG(dbgs() << TAG << "Replace globalization call " << *CB
2748                         << " with " << AllocSize->getZExtValue()
2749                         << " bytes of shared memory\n");
2750 
2751       // Create a new shared memory buffer of the same size as the allocation
2752       // and replace all the uses of the original allocation with it.
2753       Module *M = CB->getModule();
2754       Type *Int8Ty = Type::getInt8Ty(M->getContext());
2755       Type *Int8ArrTy = ArrayType::get(Int8Ty, AllocSize->getZExtValue());
2756       auto *SharedMem = new GlobalVariable(
2757           *M, Int8ArrTy, /* IsConstant */ false, GlobalValue::InternalLinkage,
2758           UndefValue::get(Int8ArrTy), CB->getName(), nullptr,
2759           GlobalValue::NotThreadLocal,
2760           static_cast<unsigned>(AddressSpace::Shared));
2761       auto *NewBuffer =
2762           ConstantExpr::getPointerCast(SharedMem, Int8Ty->getPointerTo());
2763 
2764       auto Remark = [&](OptimizationRemark OR) {
2765         return OR << "Replaced globalized variable with "
2766                   << ore::NV("SharedMemory", AllocSize->getZExtValue())
2767                   << ((AllocSize->getZExtValue() != 1) ? " bytes " : " byte ")
2768                   << "of shared memory.";
2769       };
2770       A.emitRemark<OptimizationRemark>(CB, "OMP111", Remark);
2771 
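           // Note: a fixed 32-byte alignment is used here, presumably as a
           // conservative choice since the original allocation's alignment is
           // not tracked at this point.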
2772       SharedMem->setAlignment(MaybeAlign(32));
2773 
2774       A.changeValueAfterManifest(*CB, *NewBuffer);
2775       A.deleteAfterManifest(*CB);
2776       A.deleteAfterManifest(*FreeCalls.front());
2777 
2778       NumBytesMovedToSharedMemory += AllocSize->getZExtValue();
2779       Changed = ChangeStatus::CHANGED;
2780     }
2781 
2782     return Changed;
2783   }
2784 
2785   ChangeStatus updateImpl(Attributor &A) override {
2786     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2787     auto &RFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
2788     Function *F = getAnchorScope();
2789 
2790     auto NumMallocCalls = MallocCalls.size();
2791 
2792     // Only keep malloc calls with a constant size run by the initial thread.
2793     for (User *U : RFI.Declaration->users()) {
2794       const auto &ED = A.getAAFor<AAExecutionDomain>(
2795           *this, IRPosition::function(*F), DepClassTy::REQUIRED);
2796       if (CallBase *CB = dyn_cast<CallBase>(U))
2797         if (!dyn_cast<ConstantInt>(CB->getArgOperand(0)) ||
2798             !ED.isExecutedByInitialThreadOnly(*CB))
2799           MallocCalls.erase(CB);
2800     }
2801 
2802     findPotentialRemovedFreeCalls(A);
2803 
2804     if (NumMallocCalls != MallocCalls.size())
2805       return ChangeStatus::CHANGED;
2806 
2807     return ChangeStatus::UNCHANGED;
2808   }
2809 
2810   /// Collection of all malloc calls in a function.
2811   SmallPtrSet<CallBase *, 4> MallocCalls;
2812   /// Collection of potentially removed free calls in a function.
2813   SmallPtrSet<CallBase *, 4> PotentialRemovedFreeCalls;
2814 };
2815 
2816 struct AAKernelInfo : public StateWrapper<KernelInfoState, AbstractAttribute> {
2817   using Base = StateWrapper<KernelInfoState, AbstractAttribute>;
2818   AAKernelInfo(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
2819 
2820   /// Statistics are tracked as part of manifest for now.
2821   void trackStatistics() const override {}
2822 
2823   /// See AbstractAttribute::getAsStr()
2824   const std::string getAsStr() const override {
2825     if (!isValidState())
2826       return "<invalid>";
2827     return std::string(SPMDCompatibilityTracker.isAssumed() ? "SPMD"
2828                                                             : "generic") +
2829            std::string(SPMDCompatibilityTracker.isAtFixpoint() ? " [FIX]"
2830                                                                : "") +
2831            std::string(" #PRs: ") +
2832            std::to_string(ReachedKnownParallelRegions.size()) +
2833            ", #Unknown PRs: " +
2834            std::to_string(ReachedUnknownParallelRegions.size());
2835   }
2836 
2837   /// Create an abstract attribute view for the position \p IRP.
2838   static AAKernelInfo &createForPosition(const IRPosition &IRP, Attributor &A);
2839 
2840   /// See AbstractAttribute::getName()
2841   const std::string getName() const override { return "AAKernelInfo"; }
2842 
2843   /// See AbstractAttribute::getIdAddr()
2844   const char *getIdAddr() const override { return &ID; }
2845 
2846   /// This function should return true if the type of the \p AA is AAKernelInfo
2847   static bool classof(const AbstractAttribute *AA) {
2848     return (AA->getIdAddr() == &ID);
2849   }
2850 
2851   static const char ID;
2852 };
2853 
2854 /// The function kernel info abstract attribute, basically, what can we say
2855 /// about a function with regard to the KernelInfoState.
2856 struct AAKernelInfoFunction : AAKernelInfo {
2857   AAKernelInfoFunction(const IRPosition &IRP, Attributor &A)
2858       : AAKernelInfo(IRP, A) {}
2859 
2860   SmallPtrSet<Instruction *, 4> GuardedInstructions;
2861 
2862   SmallPtrSetImpl<Instruction *> &getGuardedInstructions() {
2863     return GuardedInstructions;
2864   }
2865 
2866   /// See AbstractAttribute::initialize(...).
2867   void initialize(Attributor &A) override {
2868     // This is a high-level transform that might change the constant arguments
2869     // of the init and deinit calls. We need to tell the Attributor about this
2870     // to avoid other parts using the current constant value for simplification.
2871     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
2872 
2873     Function *Fn = getAnchorScope();
2874     if (!OMPInfoCache.Kernels.count(Fn))
2875       return;
2876 
2877     // Add itself to the reaching kernel entries and set IsKernelEntry.
2878     ReachingKernelEntries.insert(Fn);
2879     IsKernelEntry = true;
2880 
2881     OMPInformationCache::RuntimeFunctionInfo &InitRFI =
2882         OMPInfoCache.RFIs[OMPRTL___kmpc_target_init];
2883     OMPInformationCache::RuntimeFunctionInfo &DeinitRFI =
2884         OMPInfoCache.RFIs[OMPRTL___kmpc_target_deinit];
2885 
2886     // For kernels we perform more initialization work; first we find the init
2887     // and deinit calls.
2888     auto StoreCallBase = [](Use &U,
2889                             OMPInformationCache::RuntimeFunctionInfo &RFI,
2890                             CallBase *&Storage) {
2891       CallBase *CB = OpenMPOpt::getCallIfRegularCall(U, &RFI);
2892       assert(CB &&
2893              "Unexpected use of __kmpc_target_init or __kmpc_target_deinit!");
2894       assert(!Storage &&
2895              "Multiple uses of __kmpc_target_init or __kmpc_target_deinit!");
2896       Storage = CB;
2897       return false;
2898     };
2899     InitRFI.foreachUse(
2900         [&](Use &U, Function &) {
2901           StoreCallBase(U, InitRFI, KernelInitCB);
2902           return false;
2903         },
2904         Fn);
2905     DeinitRFI.foreachUse(
2906         [&](Use &U, Function &) {
2907           StoreCallBase(U, DeinitRFI, KernelDeinitCB);
2908           return false;
2909         },
2910         Fn);
2911 
2912     // Ignore kernels without initializers such as global constructors.
2913     if (!KernelInitCB || !KernelDeinitCB) {
2914       indicateOptimisticFixpoint();
2915       return;
2916     }
2917 
2918     // For kernels we might need to initialize/finalize the IsSPMD state and
2919     // we need to register a simplification callback so that the Attributor
2920     // knows the constant arguments to __kmpc_target_init and
2921     // __kmpc_target_deinit might actually change.
2922 
2923     Attributor::SimplifictionCallbackTy StateMachineSimplifyCB =
2924         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2925             bool &UsedAssumedInformation) -> Optional<Value *> {
2926       // IRP represents the "use generic state machine" argument of an
2927       // __kmpc_target_init call. We will answer this one with the internal
2928       // state. As long as we are not in an invalid state, we will create a
2929       // custom state machine so the value should be an `i1 false`. If we are
2930       // in an invalid state, we won't change the value that is in the IR.
2931       if (!isValidState())
2932         return nullptr;
2933       // If we have disabled state machine rewrites, don't make a custom one.
2934       if (DisableOpenMPOptStateMachineRewrite)
2935         return nullptr;
2936       if (AA)
2937         A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2938       UsedAssumedInformation = !isAtFixpoint();
2939       auto *FalseVal =
2940           ConstantInt::getBool(IRP.getAnchorValue().getContext(), 0);
2941       return FalseVal;
2942     };
2943 
2944     Attributor::SimplifictionCallbackTy IsSPMDModeSimplifyCB =
2945         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2946             bool &UsedAssumedInformation) -> Optional<Value *> {
2947       // IRP represents the "IsSPMD" argument of an __kmpc_target_init or
2948       // __kmpc_target_deinit call. We will answer this one with the
2949       // internal state of the SPMDCompatibilityTracker, i.e., whether we
2950       // assume the kernel can be executed in SPMD mode.
2951       if (!SPMDCompatibilityTracker.isValidState())
2952         return nullptr;
2953       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2954         if (AA)
2955           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2956         UsedAssumedInformation = true;
2957       } else {
2958         UsedAssumedInformation = false;
2959       }
2960       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2961                                        SPMDCompatibilityTracker.isAssumed());
2962       return Val;
2963     };
2964 
2965     Attributor::SimplifictionCallbackTy IsGenericModeSimplifyCB =
2966         [&](const IRPosition &IRP, const AbstractAttribute *AA,
2967             bool &UsedAssumedInformation) -> Optional<Value *> {
2968       // IRP represents the "RequiresFullRuntime" argument of an
2969       // __kmpc_target_init or __kmpc_target_deinit call. We will answer this
2970       // one with the internal state of the SPMDCompatibilityTracker, so if
2971       // generic then true, if SPMD then false.
2972       if (!SPMDCompatibilityTracker.isValidState())
2973         return nullptr;
2974       if (!SPMDCompatibilityTracker.isAtFixpoint()) {
2975         if (AA)
2976           A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
2977         UsedAssumedInformation = true;
2978       } else {
2979         UsedAssumedInformation = false;
2980       }
2981       auto *Val = ConstantInt::getBool(IRP.getAnchorValue().getContext(),
2982                                        !SPMDCompatibilityTracker.isAssumed());
2983       return Val;
2984     };
2985 
2986     constexpr const int InitIsSPMDArgNo = 1;
2987     constexpr const int DeinitIsSPMDArgNo = 1;
2988     constexpr const int InitUseStateMachineArgNo = 2;
2989     constexpr const int InitRequiresFullRuntimeArgNo = 3;
2990     constexpr const int DeinitRequiresFullRuntimeArgNo = 2;
2991     A.registerSimplificationCallback(
2992         IRPosition::callsite_argument(*KernelInitCB, InitUseStateMachineArgNo),
2993         StateMachineSimplifyCB);
2994     A.registerSimplificationCallback(
2995         IRPosition::callsite_argument(*KernelInitCB, InitIsSPMDArgNo),
2996         IsSPMDModeSimplifyCB);
2997     A.registerSimplificationCallback(
2998         IRPosition::callsite_argument(*KernelDeinitCB, DeinitIsSPMDArgNo),
2999         IsSPMDModeSimplifyCB);
3000     A.registerSimplificationCallback(
3001         IRPosition::callsite_argument(*KernelInitCB,
3002                                       InitRequiresFullRuntimeArgNo),
3003         IsGenericModeSimplifyCB);
3004     A.registerSimplificationCallback(
3005         IRPosition::callsite_argument(*KernelDeinitCB,
3006                                       DeinitRequiresFullRuntimeArgNo),
3007         IsGenericModeSimplifyCB);
3008 
3009     // Check if we know we are in SPMD-mode already.
3010     ConstantInt *IsSPMDArg =
3011         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3012     if (IsSPMDArg && !IsSPMDArg->isZero())
3013       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3014     // This is a generic region but SPMDization is disabled so stop tracking.
3015     else if (DisableOpenMPOptSPMDization)
3016       SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3017   }
3018 
3019   /// Modify the IR based on the KernelInfoState as the fixpoint iteration is
3020   /// finished now.
3021   ChangeStatus manifest(Attributor &A) override {
3022     // If we are not looking at a kernel with __kmpc_target_init and
3023     // __kmpc_target_deinit calls, we cannot actually manifest the information.
3024     if (!KernelInitCB || !KernelDeinitCB)
3025       return ChangeStatus::UNCHANGED;
3026 
3027     // Known SPMD-mode kernels need no manifest changes.
3028     if (SPMDCompatibilityTracker.isKnown())
3029       return ChangeStatus::UNCHANGED;
3030 
3031     // If we can, we change the execution mode to SPMD-mode; otherwise we build
3032     // a custom state machine.
3033     if (!mayContainParallelRegion() || !changeToSPMDMode(A))
3034       buildCustomStateMachine(A);
3035 
3036     return ChangeStatus::CHANGED;
3037   }
3038 
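       /// Try to rewrite this kernel to SPMD mode. Returns false if the tracked
       /// SPMD-incompatible side effects prevent it; otherwise guards those
       /// instructions and adjusts the execution mode and runtime calls.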
3039   bool changeToSPMDMode(Attributor &A) {
3040     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3041 
3042     if (!SPMDCompatibilityTracker.isAssumed()) {
3043       for (Instruction *NonCompatibleI : SPMDCompatibilityTracker) {
3044         if (!NonCompatibleI)
3045           continue;
3046 
3047         // Skip diagnostics on calls to known OpenMP runtime functions for now.
3048         if (auto *CB = dyn_cast<CallBase>(NonCompatibleI))
3049           if (OMPInfoCache.RTLFunctions.contains(CB->getCalledFunction()))
3050             continue;
3051 
3052         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3053           ORA << "Value has potential side effects preventing SPMD-mode "
3054                  "execution";
3055           if (isa<CallBase>(NonCompatibleI)) {
3056             ORA << ". Add `__attribute__((assume(\"ompx_spmd_amenable\")))` to "
3057                    "the called function to override";
3058           }
3059           return ORA << ".";
3060         };
3061         A.emitRemark<OptimizationRemarkAnalysis>(NonCompatibleI, "OMP121",
3062                                                  Remark);
3063 
3064         LLVM_DEBUG(dbgs() << TAG << "SPMD-incompatible side-effect: "
3065                           << *NonCompatibleI << "\n");
3066       }
3067 
3068       return false;
3069     }
3070 
3071     auto CreateGuardedRegion = [&](Instruction *RegionStartI,
3072                                    Instruction *RegionEndI) {
3073       LoopInfo *LI = nullptr;
3074       DominatorTree *DT = nullptr;
3075       MemorySSAUpdater *MSU = nullptr;
3076       using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
3077 
3078       BasicBlock *ParentBB = RegionStartI->getParent();
3079       Function *Fn = ParentBB->getParent();
3080       Module &M = *Fn->getParent();
3081 
3082       // Create all the blocks and logic.
3083       // ParentBB:
3084       //    goto RegionCheckTidBB
3085       // RegionCheckTidBB:
3086       //    Tid = __kmpc_hardware_thread_id()
3087       //    if (Tid != 0)
3088       //        goto RegionBarrierBB
3089       // RegionStartBB:
3090       //    <execute instructions guarded>
3091       //    goto RegionEndBB
3092       // RegionEndBB:
3093       //    <store escaping values to shared mem>
3094       //    goto RegionBarrierBB
3095       //  RegionBarrierBB:
3096       //    __kmpc_barrier_simple_spmd()
3097       //    // second barrier is omitted if lacking escaping values.
3098       //    <load escaping values from shared mem>
3099       //    __kmpc_barrier_simple_spmd()
3100       //    goto RegionExitBB
3101       // RegionExitBB:
3102       //    <execute rest of instructions>
3103 
3104       BasicBlock *RegionEndBB = SplitBlock(ParentBB, RegionEndI->getNextNode(),
3105                                            DT, LI, MSU, "region.guarded.end");
3106       BasicBlock *RegionBarrierBB =
3107           SplitBlock(RegionEndBB, &*RegionEndBB->getFirstInsertionPt(), DT, LI,
3108                      MSU, "region.barrier");
3109       BasicBlock *RegionExitBB =
3110           SplitBlock(RegionBarrierBB, &*RegionBarrierBB->getFirstInsertionPt(),
3111                      DT, LI, MSU, "region.exit");
3112       BasicBlock *RegionStartBB =
3113           SplitBlock(ParentBB, RegionStartI, DT, LI, MSU, "region.guarded");
3114 
3115       assert(ParentBB->getUniqueSuccessor() == RegionStartBB &&
3116              "Expected a different CFG");
3117 
3118       BasicBlock *RegionCheckTidBB = SplitBlock(
3119           ParentBB, ParentBB->getTerminator(), DT, LI, MSU, "region.check.tid");
3120 
3121       // Register basic blocks with the Attributor.
3122       A.registerManifestAddedBasicBlock(*RegionEndBB);
3123       A.registerManifestAddedBasicBlock(*RegionBarrierBB);
3124       A.registerManifestAddedBasicBlock(*RegionExitBB);
3125       A.registerManifestAddedBasicBlock(*RegionStartBB);
3126       A.registerManifestAddedBasicBlock(*RegionCheckTidBB);
3127 
3128       bool HasBroadcastValues = false;
3129       // Find escaping outputs from the guarded region to outside users and
3130       // broadcast their values to them.
3131       for (Instruction &I : *RegionStartBB) {
3132         SmallPtrSet<Instruction *, 4> OutsideUsers;
3133         for (User *Usr : I.users()) {
3134           Instruction &UsrI = *cast<Instruction>(Usr);
3135           if (UsrI.getParent() != RegionStartBB)
3136             OutsideUsers.insert(&UsrI);
3137         }
3138 
3139         if (OutsideUsers.empty())
3140           continue;
3141 
3142         HasBroadcastValues = true;
3143 
3144         // Emit a global variable in shared memory to store the broadcasted
3145         // value.
3146         auto *SharedMem = new GlobalVariable(
3147             M, I.getType(), /* IsConstant */ false,
3148             GlobalValue::InternalLinkage, UndefValue::get(I.getType()),
3149             I.getName() + ".guarded.output.alloc", nullptr,
3150             GlobalValue::NotThreadLocal,
3151             static_cast<unsigned>(AddressSpace::Shared));
3152 
3153         // Emit a store instruction to update the value.
3154         new StoreInst(&I, SharedMem, RegionEndBB->getTerminator());
3155 
3156         LoadInst *LoadI = new LoadInst(I.getType(), SharedMem,
3157                                        I.getName() + ".guarded.output.load",
3158                                        RegionBarrierBB->getTerminator());
3159 
3160         // Replace outside uses of the output value with the loaded value.
3161         for (Instruction *UsrI : OutsideUsers) {
3162           assert(UsrI->getParent() == RegionExitBB &&
3163                  "Expected escaping users in exit region");
3164           UsrI->replaceUsesOfWith(&I, LoadI);
3165         }
3166       }
3167 
3168       auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3169 
3170       // Go to tid check BB in ParentBB.
3171       const DebugLoc DL = ParentBB->getTerminator()->getDebugLoc();
3172       ParentBB->getTerminator()->eraseFromParent();
3173       OpenMPIRBuilder::LocationDescription Loc(
3174           InsertPointTy(ParentBB, ParentBB->end()), DL);
3175       OMPInfoCache.OMPBuilder.updateToLocation(Loc);
3176       auto *SrcLocStr = OMPInfoCache.OMPBuilder.getOrCreateSrcLocStr(Loc);
3177       Value *Ident = OMPInfoCache.OMPBuilder.getOrCreateIdent(SrcLocStr);
3178       BranchInst::Create(RegionCheckTidBB, ParentBB)->setDebugLoc(DL);
3179 
3180       // Add check for Tid in RegionCheckTidBB
3181       RegionCheckTidBB->getTerminator()->eraseFromParent();
3182       OpenMPIRBuilder::LocationDescription LocRegionCheckTid(
3183           InsertPointTy(RegionCheckTidBB, RegionCheckTidBB->end()), DL);
3184       OMPInfoCache.OMPBuilder.updateToLocation(LocRegionCheckTid);
3185       FunctionCallee HardwareTidFn =
3186           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3187               M, OMPRTL___kmpc_get_hardware_thread_id_in_block);
3188       Value *Tid =
3189           OMPInfoCache.OMPBuilder.Builder.CreateCall(HardwareTidFn, {});
3190       Value *TidCheck = OMPInfoCache.OMPBuilder.Builder.CreateIsNull(Tid);
3191       OMPInfoCache.OMPBuilder.Builder
3192           .CreateCondBr(TidCheck, RegionStartBB, RegionBarrierBB)
3193           ->setDebugLoc(DL);
3194 
3195       // First barrier for synchronization; it ensures the main thread has
3196       // updated the values.
3197       FunctionCallee BarrierFn =
3198           OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3199               M, OMPRTL___kmpc_barrier_simple_spmd);
3200       OMPInfoCache.OMPBuilder.updateToLocation(InsertPointTy(
3201           RegionBarrierBB, RegionBarrierBB->getFirstInsertionPt()));
3202       OMPInfoCache.OMPBuilder.Builder.CreateCall(BarrierFn, {Ident, Tid})
3203           ->setDebugLoc(DL);
3204 
3205       // Second barrier ensures workers have read broadcast values.
3206       if (HasBroadcastValues)
3207         CallInst::Create(BarrierFn, {Ident, Tid}, "",
3208                          RegionBarrierBB->getTerminator())
3209             ->setDebugLoc(DL);
3210     };
3211 
3212     auto &AllocSharedRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
3213     SmallPtrSet<BasicBlock *, 8> Visited;
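         // First, sink user-free side-effect instructions that need guarding
         // next to the following guarded instruction in their block; this keeps
         // the number of guarded regions, and thus barriers, small.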
3214     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3215       BasicBlock *BB = GuardedI->getParent();
3216       if (!Visited.insert(BB).second)
3217         continue;
3218 
3219       SmallVector<std::pair<Instruction *, Instruction *>> Reorders;
3220       Instruction *LastEffect = nullptr;
3221       BasicBlock::reverse_iterator IP = BB->rbegin(), IPEnd = BB->rend();
3222       while (++IP != IPEnd) {
3223         if (!IP->mayHaveSideEffects() && !IP->mayReadFromMemory())
3224           continue;
3225         Instruction *I = &*IP;
3226         if (OpenMPOpt::getCallIfRegularCall(*I, &AllocSharedRFI))
3227           continue;
3228         if (!I->user_empty() || !SPMDCompatibilityTracker.contains(I)) {
3229           LastEffect = nullptr;
3230           continue;
3231         }
3232         if (LastEffect)
3233           Reorders.push_back({I, LastEffect});
3234         LastEffect = &*IP;
3235       }
3236       for (auto &Reorder : Reorders)
3237         Reorder.first->moveBefore(Reorder.second);
3238     }
3239 
3240     SmallVector<std::pair<Instruction *, Instruction *>, 4> GuardedRegions;
3241 
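         // Now collect the maximal runs of consecutive to-be-guarded
         // instructions per block; each run becomes one guarded region.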
3242     for (Instruction *GuardedI : SPMDCompatibilityTracker) {
3243       BasicBlock *BB = GuardedI->getParent();
3244       auto *CalleeAA = A.lookupAAFor<AAKernelInfo>(
3245           IRPosition::function(*GuardedI->getFunction()), nullptr,
3246           DepClassTy::NONE);
3247       assert(CalleeAA != nullptr && "Expected Callee AAKernelInfo");
3248       auto &CalleeAAFunction = *cast<AAKernelInfoFunction>(CalleeAA);
3249       // Continue if instruction is already guarded.
3250       if (CalleeAAFunction.getGuardedInstructions().contains(GuardedI))
3251         continue;
3252 
3253       Instruction *GuardedRegionStart = nullptr, *GuardedRegionEnd = nullptr;
3254       for (Instruction &I : *BB) {
3255         // If instruction I needs to be guarded, update the guarded region
3256         // bounds.
3257         if (SPMDCompatibilityTracker.contains(&I)) {
3258           CalleeAAFunction.getGuardedInstructions().insert(&I);
3259           if (GuardedRegionStart)
3260             GuardedRegionEnd = &I;
3261           else
3262             GuardedRegionStart = GuardedRegionEnd = &I;
3263 
3264           continue;
3265         }
3266 
3267         // Instruction I does not need guarding; store
3268         // any region found and reset the bounds.
3269         if (GuardedRegionStart) {
3270           GuardedRegions.push_back(
3271               std::make_pair(GuardedRegionStart, GuardedRegionEnd));
3272           GuardedRegionStart = nullptr;
3273           GuardedRegionEnd = nullptr;
3274         }
3275       }
3276     }
3277 
3278     for (auto &GR : GuardedRegions)
3279       CreateGuardedRegion(GR.first, GR.second);
3280 
3281     // Adjust the global exec mode flag that tells the runtime what mode this
3282     // kernel is executed in.
3283     Function *Kernel = getAnchorScope();
3284     GlobalVariable *ExecMode = Kernel->getParent()->getGlobalVariable(
3285         (Kernel->getName() + "_exec_mode").str());
3286     assert(ExecMode && "Kernel without exec mode?");
3287     assert(ExecMode->getInitializer() && "ExecMode doesn't have initializer!");
3288 
3289     // Set the global exec mode flag to indicate Generic-SPMD mode.
3290     assert(isa<ConstantInt>(ExecMode->getInitializer()) &&
3291            "ExecMode is not an integer!");
3292     const int8_t ExecModeVal =
3293         cast<ConstantInt>(ExecMode->getInitializer())->getSExtValue();
3294     assert(ExecModeVal == OMP_TGT_EXEC_MODE_GENERIC &&
3295            "Initially non-SPMD kernel has SPMD exec mode!");
3296     ExecMode->setInitializer(
3297         ConstantInt::get(ExecMode->getInitializer()->getType(),
3298                          ExecModeVal | OMP_TGT_EXEC_MODE_GENERIC_SPMD));
3299 
3300     // Next rewrite the init and deinit calls to indicate we use SPMD-mode now.
3301     const int InitIsSPMDArgNo = 1;
3302     const int DeinitIsSPMDArgNo = 1;
3303     const int InitUseStateMachineArgNo = 2;
3304     const int InitRequiresFullRuntimeArgNo = 3;
3305     const int DeinitRequiresFullRuntimeArgNo = 2;
3306 
3307     auto &Ctx = getAnchorValue().getContext();
3308     A.changeUseAfterManifest(KernelInitCB->getArgOperandUse(InitIsSPMDArgNo),
3309                              *ConstantInt::getBool(Ctx, 1));
3310     A.changeUseAfterManifest(
3311         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo),
3312         *ConstantInt::getBool(Ctx, 0));
3313     A.changeUseAfterManifest(
3314         KernelDeinitCB->getArgOperandUse(DeinitIsSPMDArgNo),
3315         *ConstantInt::getBool(Ctx, 1));
3316     A.changeUseAfterManifest(
3317         KernelInitCB->getArgOperandUse(InitRequiresFullRuntimeArgNo),
3318         *ConstantInt::getBool(Ctx, 0));
3319     A.changeUseAfterManifest(
3320         KernelDeinitCB->getArgOperandUse(DeinitRequiresFullRuntimeArgNo),
3321         *ConstantInt::getBool(Ctx, 0));
3322 
3323     ++NumOpenMPTargetRegionKernelsSPMD;
3324 
3325     auto Remark = [&](OptimizationRemark OR) {
3326       return OR << "Transformed generic-mode kernel to SPMD-mode.";
3327     };
3328     A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP120", Remark);
3329     return true;
3330   };
3331 
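       /// Build a kernel-specific state machine that dispatches directly to the
       /// parallel regions reachable from this kernel, with an indirect-call
       /// fallback only if unknown parallel regions remain.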
3332   ChangeStatus buildCustomStateMachine(Attributor &A) {
3333     // If we have disabled state machine rewrites, don't make a custom one
3334     if (DisableOpenMPOptStateMachineRewrite)
3335       return indicatePessimisticFixpoint();
3336 
3337     assert(ReachedKnownParallelRegions.isValidState() &&
3338            "Custom state machine with invalid parallel region states?");
3339 
3340     const int InitIsSPMDArgNo = 1;
3341     const int InitUseStateMachineArgNo = 2;
3342 
3343     // Check that the current configuration is generic mode with the generic
3344     // state machine. If we are already in SPMD mode or already use a custom
3345     // state machine there is nothing to do. If either argument is anything
3346     // but a constant, something is odd and we give up.
3347     ConstantInt *UseStateMachine = dyn_cast<ConstantInt>(
3348         KernelInitCB->getArgOperand(InitUseStateMachineArgNo));
3349     ConstantInt *IsSPMD =
3350         dyn_cast<ConstantInt>(KernelInitCB->getArgOperand(InitIsSPMDArgNo));
3351 
3352     // If we are stuck with generic mode, try to create a custom device (=GPU)
3353     // state machine which is specialized for the parallel regions that are
3354     // reachable by the kernel.
3355     if (!UseStateMachine || UseStateMachine->isZero() || !IsSPMD ||
3356         !IsSPMD->isZero())
3357       return ChangeStatus::UNCHANGED;
3358 
3359     // If not SPMD mode, indicate we use a custom state machine now.
3360     auto &Ctx = getAnchorValue().getContext();
3361     auto *FalseVal = ConstantInt::getBool(Ctx, 0);
3362     A.changeUseAfterManifest(
3363         KernelInitCB->getArgOperandUse(InitUseStateMachineArgNo), *FalseVal);
3364 
3365     // If we don't actually need a state machine we are done here. This can
3366     // happen if there simply are no parallel regions. In the resulting kernel
3367     // all worker threads will simply exit right away, leaving the main thread
3368     // to do the work alone.
3369     if (!mayContainParallelRegion()) {
3370       ++NumOpenMPTargetRegionKernelsWithoutStateMachine;
3371 
3372       auto Remark = [&](OptimizationRemark OR) {
3373         return OR << "Removing unused state machine from generic-mode kernel.";
3374       };
3375       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP130", Remark);
3376 
3377       return ChangeStatus::CHANGED;
3378     }
3379 
3380     // Keep track in the statistics of our new shiny custom state machine.
3381     if (ReachedUnknownParallelRegions.empty()) {
3382       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithoutFallback;
3383 
3384       auto Remark = [&](OptimizationRemark OR) {
3385         return OR << "Rewriting generic-mode kernel with a customized state "
3386                      "machine.";
3387       };
3388       A.emitRemark<OptimizationRemark>(KernelInitCB, "OMP131", Remark);
3389     } else {
3390       ++NumOpenMPTargetRegionKernelsCustomStateMachineWithFallback;
3391 
3392       auto Remark = [&](OptimizationRemarkAnalysis OR) {
3393         return OR << "Generic-mode kernel is executed with a customized state "
3394                      "machine that requires a fallback.";
3395       };
3396       A.emitRemark<OptimizationRemarkAnalysis>(KernelInitCB, "OMP132", Remark);
3397 
3398       // Tell the user why we ended up with a fallback.
3399       for (CallBase *UnknownParallelRegionCB : ReachedUnknownParallelRegions) {
3400         if (!UnknownParallelRegionCB)
3401           continue;
3402         auto Remark = [&](OptimizationRemarkAnalysis ORA) {
3403           return ORA << "Call may contain unknown parallel regions. Use "
3404                      << "`__attribute__((assume(\"omp_no_parallelism\")))` to "
3405                         "override.";
3406         };
3407         A.emitRemark<OptimizationRemarkAnalysis>(UnknownParallelRegionCB,
3408                                                  "OMP133", Remark);
3409       }
3410     }
3411 
3412     // Create all the blocks:
3413     //
3414     //                       InitCB = __kmpc_target_init(...)
3415     //                       bool IsWorker = InitCB >= 0;
3416     //                       if (IsWorker) {
3417     // SMBeginBB:               __kmpc_barrier_simple_spmd(...);
3418     //                         void *WorkFn;
3419     //                         bool Active = __kmpc_kernel_parallel(&WorkFn);
3420     //                         if (!WorkFn) return;
3421     // SMIsActiveCheckBB:       if (Active) {
3422     // SMIfCascadeCurrentBB:      if      (WorkFn == <ParFn0>)
3423     //                              ParFn0(...);
3424     // SMIfCascadeCurrentBB:      else if (WorkFn == <ParFn1>)
3425     //                              ParFn1(...);
3426     //                            ...
3427     // SMIfCascadeCurrentBB:      else
3428     //                              ((WorkFnTy*)WorkFn)(...);
3429     // SMEndParallelBB:           __kmpc_kernel_end_parallel(...);
3430     //                          }
3431     // SMDoneBB:                __kmpc_barrier_simple_spmd(...);
3432     //                          goto SMBeginBB;
3433     //                       }
3434     // UserCodeEntryBB:      // user code
3435     //                       __kmpc_target_deinit(...)
3436     //
3437     Function *Kernel = getAssociatedFunction();
3438     assert(Kernel && "Expected an associated function!");
3439 
3440     BasicBlock *InitBB = KernelInitCB->getParent();
3441     BasicBlock *UserCodeEntryBB = InitBB->splitBasicBlock(
3442         KernelInitCB->getNextNode(), "thread.user_code.check");
3443     BasicBlock *StateMachineBeginBB = BasicBlock::Create(
3444         Ctx, "worker_state_machine.begin", Kernel, UserCodeEntryBB);
3445     BasicBlock *StateMachineFinishedBB = BasicBlock::Create(
3446         Ctx, "worker_state_machine.finished", Kernel, UserCodeEntryBB);
3447     BasicBlock *StateMachineIsActiveCheckBB = BasicBlock::Create(
3448         Ctx, "worker_state_machine.is_active.check", Kernel, UserCodeEntryBB);
3449     BasicBlock *StateMachineIfCascadeCurrentBB =
3450         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3451                            Kernel, UserCodeEntryBB);
3452     BasicBlock *StateMachineEndParallelBB =
3453         BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.end",
3454                            Kernel, UserCodeEntryBB);
3455     BasicBlock *StateMachineDoneBarrierBB = BasicBlock::Create(
3456         Ctx, "worker_state_machine.done.barrier", Kernel, UserCodeEntryBB);
3457     A.registerManifestAddedBasicBlock(*InitBB);
3458     A.registerManifestAddedBasicBlock(*UserCodeEntryBB);
3459     A.registerManifestAddedBasicBlock(*StateMachineBeginBB);
3460     A.registerManifestAddedBasicBlock(*StateMachineFinishedBB);
3461     A.registerManifestAddedBasicBlock(*StateMachineIsActiveCheckBB);
3462     A.registerManifestAddedBasicBlock(*StateMachineIfCascadeCurrentBB);
3463     A.registerManifestAddedBasicBlock(*StateMachineEndParallelBB);
3464     A.registerManifestAddedBasicBlock(*StateMachineDoneBarrierBB);
3465 
3466     const DebugLoc &DLoc = KernelInitCB->getDebugLoc();
3467     ReturnInst::Create(Ctx, StateMachineFinishedBB)->setDebugLoc(DLoc);
3468 
3469     InitBB->getTerminator()->eraseFromParent();
3470     Instruction *IsWorker =
3471         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_NE, KernelInitCB,
3472                          ConstantInt::get(KernelInitCB->getType(), -1),
3473                          "thread.is_worker", InitBB);
3474     IsWorker->setDebugLoc(DLoc);
3475     BranchInst::Create(StateMachineBeginBB, UserCodeEntryBB, IsWorker, InitBB);
3476 
3477     Module &M = *Kernel->getParent();
3478 
3479     // Create local storage for the work function pointer.
3480     const DataLayout &DL = M.getDataLayout();
3481     Type *VoidPtrTy = Type::getInt8PtrTy(Ctx);
3482     Instruction *WorkFnAI =
3483         new AllocaInst(VoidPtrTy, DL.getAllocaAddrSpace(), nullptr,
3484                        "worker.work_fn.addr", &Kernel->getEntryBlock().front());
3485     WorkFnAI->setDebugLoc(DLoc);
3486 
3487     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3488     OMPInfoCache.OMPBuilder.updateToLocation(
3489         OpenMPIRBuilder::LocationDescription(
3490             IRBuilder<>::InsertPoint(StateMachineBeginBB,
3491                                      StateMachineBeginBB->end()),
3492             DLoc));
3493 
3494     Value *Ident = KernelInitCB->getArgOperand(0);
3495     Value *GTid = KernelInitCB;
3496 
3497     FunctionCallee BarrierFn =
3498         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3499             M, OMPRTL___kmpc_barrier_simple_spmd);
3500     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineBeginBB)
3501         ->setDebugLoc(DLoc);
3502 
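         // __kmpc_kernel_parallel below expects the work-function storage as a
         // generic pointer; cast the alloca if it lives in a different address
         // space (e.g., the private address space on AMDGPU).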
3503     if (WorkFnAI->getType()->getPointerAddressSpace() !=
3504         (unsigned int)AddressSpace::Generic) {
3505       WorkFnAI = new AddrSpaceCastInst(
3506           WorkFnAI,
3507           PointerType::getWithSamePointeeType(
3508               cast<PointerType>(WorkFnAI->getType()),
3509               (unsigned int)AddressSpace::Generic),
3510           WorkFnAI->getName() + ".generic", StateMachineBeginBB);
3511       WorkFnAI->setDebugLoc(DLoc);
3512     }
3513 
3514     FunctionCallee KernelParallelFn =
3515         OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3516             M, OMPRTL___kmpc_kernel_parallel);
3517     Instruction *IsActiveWorker = CallInst::Create(
3518         KernelParallelFn, {WorkFnAI}, "worker.is_active", StateMachineBeginBB);
3519     IsActiveWorker->setDebugLoc(DLoc);
3520     Instruction *WorkFn = new LoadInst(VoidPtrTy, WorkFnAI, "worker.work_fn",
3521                                        StateMachineBeginBB);
3522     WorkFn->setDebugLoc(DLoc);
3523 
3524     FunctionType *ParallelRegionFnTy = FunctionType::get(
3525         Type::getVoidTy(Ctx), {Type::getInt16Ty(Ctx), Type::getInt32Ty(Ctx)},
3526         false);
3527     Value *WorkFnCast = BitCastInst::CreatePointerBitCastOrAddrSpaceCast(
3528         WorkFn, ParallelRegionFnTy->getPointerTo(), "worker.work_fn.addr_cast",
3529         StateMachineBeginBB);
3530 
3531     Instruction *IsDone =
3532         ICmpInst::Create(ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFn,
3533                          Constant::getNullValue(VoidPtrTy), "worker.is_done",
3534                          StateMachineBeginBB);
3535     IsDone->setDebugLoc(DLoc);
3536     BranchInst::Create(StateMachineFinishedBB, StateMachineIsActiveCheckBB,
3537                        IsDone, StateMachineBeginBB)
3538         ->setDebugLoc(DLoc);
3539 
3540     BranchInst::Create(StateMachineIfCascadeCurrentBB,
3541                        StateMachineDoneBarrierBB, IsActiveWorker,
3542                        StateMachineIsActiveCheckBB)
3543         ->setDebugLoc(DLoc);
3544 
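         // The first wrapper argument is always passed as zero here; the second
         // is the global thread id obtained from __kmpc_target_init.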
3545     Value *ZeroArg =
3546         Constant::getNullValue(ParallelRegionFnTy->getParamType(0));
3547 
3548     // Now that we have most of the CFG skeleton it is time for the if-cascade
3549     // that checks the function pointer we got from the runtime against the
3550     // parallel regions we expect, if there are any.
3551     for (int i = 0, e = ReachedKnownParallelRegions.size(); i < e; ++i) {
3552       auto *ParallelRegion = ReachedKnownParallelRegions[i];
3553       BasicBlock *PRExecuteBB = BasicBlock::Create(
3554           Ctx, "worker_state_machine.parallel_region.execute", Kernel,
3555           StateMachineEndParallelBB);
3556       CallInst::Create(ParallelRegion, {ZeroArg, GTid}, "", PRExecuteBB)
3557           ->setDebugLoc(DLoc);
3558       BranchInst::Create(StateMachineEndParallelBB, PRExecuteBB)
3559           ->setDebugLoc(DLoc);
3560 
3561       BasicBlock *PRNextBB =
3562           BasicBlock::Create(Ctx, "worker_state_machine.parallel_region.check",
3563                              Kernel, StateMachineEndParallelBB);
3564 
3565       // Check if we need to compare the pointer at all or if we can just
3566       // call the parallel region function.
3567       Value *IsPR;
3568       if (i + 1 < e || !ReachedUnknownParallelRegions.empty()) {
3569         Instruction *CmpI = ICmpInst::Create(
3570             ICmpInst::ICmp, llvm::CmpInst::ICMP_EQ, WorkFnCast, ParallelRegion,
3571             "worker.check_parallel_region", StateMachineIfCascadeCurrentBB);
3572         CmpI->setDebugLoc(DLoc);
3573         IsPR = CmpI;
3574       } else {
3575         IsPR = ConstantInt::getTrue(Ctx);
3576       }
3577 
3578       BranchInst::Create(PRExecuteBB, PRNextBB, IsPR,
3579                          StateMachineIfCascadeCurrentBB)
3580           ->setDebugLoc(DLoc);
3581       StateMachineIfCascadeCurrentBB = PRNextBB;
3582     }
3583 
3584     // At the end of the if-cascade we place the indirect function pointer call
3585     // in case we might need it, that is if there can be parallel regions we
3586     // have not handled in the if-cascade above.
3587     if (!ReachedUnknownParallelRegions.empty()) {
3588       StateMachineIfCascadeCurrentBB->setName(
3589           "worker_state_machine.parallel_region.fallback.execute");
3590       CallInst::Create(ParallelRegionFnTy, WorkFnCast, {ZeroArg, GTid}, "",
3591                        StateMachineIfCascadeCurrentBB)
3592           ->setDebugLoc(DLoc);
3593     }
3594     BranchInst::Create(StateMachineEndParallelBB,
3595                        StateMachineIfCascadeCurrentBB)
3596         ->setDebugLoc(DLoc);
3597 
3598     CallInst::Create(OMPInfoCache.OMPBuilder.getOrCreateRuntimeFunction(
3599                          M, OMPRTL___kmpc_kernel_end_parallel),
3600                      {}, "", StateMachineEndParallelBB)
3601         ->setDebugLoc(DLoc);
3602     BranchInst::Create(StateMachineDoneBarrierBB, StateMachineEndParallelBB)
3603         ->setDebugLoc(DLoc);
3604 
3605     CallInst::Create(BarrierFn, {Ident, GTid}, "", StateMachineDoneBarrierBB)
3606         ->setDebugLoc(DLoc);
3607     BranchInst::Create(StateMachineBeginBB, StateMachineDoneBarrierBB)
3608         ->setDebugLoc(DLoc);
3609 
3610     return ChangeStatus::CHANGED;
3611   }
3612 
3613   /// Fixpoint iteration update function. Will be called every time a
3614   /// dependence changes its state (and in the beginning).
3615   ChangeStatus updateImpl(Attributor &A) override {
3616     KernelInfoState StateBefore = getState();
3617 
3618     // Callback to check a read/write instruction.
3619     auto CheckRWInst = [&](Instruction &I) {
3620       // We handle calls later.
3621       if (isa<CallBase>(I))
3622         return true;
3623       // We only care about write effects.
3624       if (!I.mayWriteToMemory())
3625         return true;
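           // Stores are harmless if every underlying object is thread-local,
           // i.e., an alloca or an allocation AAHeapToStack moves to the stack.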
3626       if (auto *SI = dyn_cast<StoreInst>(&I)) {
3627         SmallVector<const Value *> Objects;
3628         getUnderlyingObjects(SI->getPointerOperand(), Objects);
3629         if (llvm::all_of(Objects,
3630                          [](const Value *Obj) { return isa<AllocaInst>(Obj); }))
3631           return true;
3632         // Check for AAHeapToStack moved objects which must not be guarded.
3633         auto &HS = A.getAAFor<AAHeapToStack>(
3634             *this, IRPosition::function(*I.getFunction()),
3635             DepClassTy::REQUIRED);
3636         if (llvm::all_of(Objects, [&HS](const Value *Obj) {
3637               auto *CB = dyn_cast<CallBase>(Obj);
3638               if (!CB)
3639                 return false;
3640               return HS.isAssumedHeapToStack(*CB);
3641             })) {
3642           return true;
3643         }
3644       }
3645 
3646       // Insert instruction that needs guarding.
3647       SPMDCompatibilityTracker.insert(&I);
3648       return true;
3649     };
3650 
3651     bool UsedAssumedInformationInCheckRWInst = false;
3652     if (!SPMDCompatibilityTracker.isAtFixpoint())
3653       if (!A.checkForAllReadWriteInstructions(
3654               CheckRWInst, *this, UsedAssumedInformationInCheckRWInst))
3655         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3656 
3657     if (!IsKernelEntry) {
3658       updateReachingKernelEntries(A);
3659       updateParallelLevels(A);
3660 
3661       if (!ParallelLevels.isValidState())
3662         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3663     }
3664 
3665     // Callback to check a call instruction.
3666     bool AllSPMDStatesWereFixed = true;
3667     auto CheckCallInst = [&](Instruction &I) {
3668       auto &CB = cast<CallBase>(I);
3669       auto &CBAA = A.getAAFor<AAKernelInfo>(
3670           *this, IRPosition::callsite_function(CB), DepClassTy::OPTIONAL);
3671       getState() ^= CBAA.getState();
3672       AllSPMDStatesWereFixed &= CBAA.SPMDCompatibilityTracker.isAtFixpoint();
3673       return true;
3674     };
3675 
3676     bool UsedAssumedInformationInCheckCallInst = false;
3677     if (!A.checkForAllCallLikeInstructions(
3678             CheckCallInst, *this, UsedAssumedInformationInCheckCallInst))
3679       return indicatePessimisticFixpoint();
3680 
3681     // If we haven't used any assumed information for the SPMD state we can fix
3682     // it.
3683     if (!UsedAssumedInformationInCheckRWInst &&
3684         !UsedAssumedInformationInCheckCallInst && AllSPMDStatesWereFixed)
3685       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3686 
3687     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3688                                      : ChangeStatus::CHANGED;
3689   }
3690 
3691 private:
3692   /// Update info regarding reaching kernels.
3693   void updateReachingKernelEntries(Attributor &A) {
3694     auto PredCallSite = [&](AbstractCallSite ACS) {
3695       Function *Caller = ACS.getInstruction()->getFunction();
3696 
3697       assert(Caller && "Caller is nullptr");
3698 
3699       auto &CAA = A.getOrCreateAAFor<AAKernelInfo>(
3700           IRPosition::function(*Caller), this, DepClassTy::REQUIRED);
3701       if (CAA.ReachingKernelEntries.isValidState()) {
3702         ReachingKernelEntries ^= CAA.ReachingKernelEntries;
3703         return true;
3704       }
3705 
3706       // We lost track of the caller of the associated function; any kernel
3707       // could reach it now.
3708       ReachingKernelEntries.indicatePessimisticFixpoint();
3709 
3710       return true;
3711     };
3712 
3713     bool AllCallSitesKnown;
3714     if (!A.checkForAllCallSites(PredCallSite, *this,
3715                                 true /* RequireAllCallSites */,
3716                                 AllCallSitesKnown))
3717       ReachingKernelEntries.indicatePessimisticFixpoint();
3718   }
3719 
3720   /// Update info regarding parallel levels.
3721   void updateParallelLevels(Attributor &A) {
3722     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3723     OMPInformationCache::RuntimeFunctionInfo &Parallel51RFI =
3724         OMPInfoCache.RFIs[OMPRTL___kmpc_parallel_51];
3725 
3726     auto PredCallSite = [&](AbstractCallSite ACS) {
3727       Function *Caller = ACS.getInstruction()->getFunction();
3728 
3729       assert(Caller && "Caller is nullptr");
3730 
3731       auto &CAA =
3732           A.getOrCreateAAFor<AAKernelInfo>(IRPosition::function(*Caller));
3733       if (CAA.ParallelLevels.isValidState()) {
3734         // Any function called by `__kmpc_parallel_51` is not folded because
3735         // the parallel level inside it is updated. Modeling this precisely
3736         // would make the analysis depend on the runtime implementation, and
3737         // any future change to that implementation could silently invalidate
3738         // the analysis. As a consequence, we are simply conservative here.
3739         if (Caller == Parallel51RFI.Declaration) {
3740           ParallelLevels.indicatePessimisticFixpoint();
3741           return true;
3742         }
3743 
3744         ParallelLevels ^= CAA.ParallelLevels;
3745 
3746         return true;
3747       }
3748 
3749       // We lost track of the caller of the associated function; any kernel
3750       // could reach it now.
3751       ParallelLevels.indicatePessimisticFixpoint();
3752 
3753       return true;
3754     };
3755 
3756     bool AllCallSitesKnown = true;
3757     if (!A.checkForAllCallSites(PredCallSite, *this,
3758                                 true /* RequireAllCallSites */,
3759                                 AllCallSitesKnown))
3760       ParallelLevels.indicatePessimisticFixpoint();
3761   }
3762 };
3763 
3764 /// The call site kernel info abstract attribute, basically, what can we say
3765 /// about a call site with regard to the KernelInfoState. For now this simply
3766 /// forwards the information from the callee.
3767 struct AAKernelInfoCallSite : AAKernelInfo {
3768   AAKernelInfoCallSite(const IRPosition &IRP, Attributor &A)
3769       : AAKernelInfo(IRP, A) {}
3770 
3771   /// See AbstractAttribute::initialize(...).
3772   void initialize(Attributor &A) override {
3773     AAKernelInfo::initialize(A);
3774 
3775     CallBase &CB = cast<CallBase>(getAssociatedValue());
3776     Function *Callee = getAssociatedFunction();
3777 
3778     // Helper to look up an assumption string.
3779     auto HasAssumption = [](CallBase &CB, StringRef AssumptionStr) {
3780       return hasAssumption(CB, AssumptionStr);
3781     };
3782 
3783     // Check for SPMD-mode assumptions.
3784     if (HasAssumption(CB, "ompx_spmd_amenable")) {
3785       SPMDCompatibilityTracker.indicateOptimisticFixpoint();
3786       indicateOptimisticFixpoint();
3787     }
3788 
3789     // First weed out calls we do not care about, that is readonly/readnone
3790     // calls, intrinsics, and "no_openmp" calls. None of these can reach a
3791     // parallel region or anything else we are looking for.
3792     if (!CB.mayWriteToMemory() || isa<IntrinsicInst>(CB)) {
3793       indicateOptimisticFixpoint();
3794       return;
3795     }
3796 
3797     // Next we check if we know the callee. If it is a known OpenMP function
3798     // we will handle them explicitly in the switch below. If it is not, we
3799     // will use an AAKernelInfo object on the callee to gather information and
3800     // merge that into the current state. The latter happens in the updateImpl.
3801     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3802     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
3803     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3804       // Unknown callees or mere declarations are not analyzable; we give up.
3805       if (!Callee || !A.isFunctionIPOAmendable(*Callee)) {
3806 
3807         // Unknown callees might contain parallel regions, except if they have
3808         // an appropriate assumption attached.
3809         if (!(HasAssumption(CB, "omp_no_openmp") ||
3810               HasAssumption(CB, "omp_no_parallelism")))
3811           ReachedUnknownParallelRegions.insert(&CB);
3812 
3813         // If SPMDCompatibilityTracker is not fixed, we need to give up on the
3814         // idea we can run something unknown in SPMD-mode.
3815         if (!SPMDCompatibilityTracker.isAtFixpoint()) {
3816           SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3817           SPMDCompatibilityTracker.insert(&CB);
3818         }
3819 
3820         // We have updated the state for this unknown call properly; there won't
3821         // be any change, so we indicate a fixpoint.
3822         indicateOptimisticFixpoint();
3823       }
3824       // If the callee is known and can be used in IPO, we will update the state
3825       // based on the callee state in updateImpl.
3826       return;
3827     }
3828 
3829     const unsigned int WrapperFunctionArgNo = 6;
3830     RuntimeFunction RF = It->getSecond();
3831     switch (RF) {
3832     // All the functions we know are compatible with SPMD mode.
3833     case OMPRTL___kmpc_is_spmd_exec_mode:
3834     case OMPRTL___kmpc_for_static_fini:
3835     case OMPRTL___kmpc_global_thread_num:
3836     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
3837     case OMPRTL___kmpc_get_hardware_num_blocks:
3838     case OMPRTL___kmpc_single:
3839     case OMPRTL___kmpc_end_single:
3840     case OMPRTL___kmpc_master:
3841     case OMPRTL___kmpc_end_master:
3842     case OMPRTL___kmpc_barrier:
3843       break;
3844     case OMPRTL___kmpc_for_static_init_4:
3845     case OMPRTL___kmpc_for_static_init_4u:
3846     case OMPRTL___kmpc_for_static_init_8:
3847     case OMPRTL___kmpc_for_static_init_8u: {
3848       // Check the schedule and allow static schedule in SPMD mode.
3849       unsigned ScheduleArgOpNo = 2;
3850       auto *ScheduleTypeCI =
3851           dyn_cast<ConstantInt>(CB.getArgOperand(ScheduleArgOpNo));
3852       unsigned ScheduleTypeVal =
3853           ScheduleTypeCI ? ScheduleTypeCI->getZExtValue() : 0;
3854       switch (OMPScheduleType(ScheduleTypeVal)) {
3855       case OMPScheduleType::Static:
3856       case OMPScheduleType::StaticChunked:
3857       case OMPScheduleType::Distribute:
3858       case OMPScheduleType::DistributeChunked:
3859         break;
3860       default:
3861         SPMDCompatibilityTracker.indicatePessimisticFixpoint();
3862         SPMDCompatibilityTracker.insert(&CB);
3863         break;
3864       };
3865     } break;
3866     case OMPRTL___kmpc_target_init:
3867       KernelInitCB = &CB;
3868       break;
3869     case OMPRTL___kmpc_target_deinit:
3870       KernelDeinitCB = &CB;
3871       break;
3872     case OMPRTL___kmpc_parallel_51:
3873       if (auto *ParallelRegion = dyn_cast<Function>(
3874               CB.getArgOperand(WrapperFunctionArgNo)->stripPointerCasts())) {
3875         ReachedKnownParallelRegions.insert(ParallelRegion);
3876         break;
3877       }
3878       // The condition above should usually get the parallel region function
3879       // pointer and record it. In the off chance it doesn't, we assume the
3880       // worst.
3881       ReachedUnknownParallelRegions.insert(&CB);
3882       break;
3883     case OMPRTL___kmpc_omp_task:
3884       // We do not look into tasks right now, just give up.
3885       SPMDCompatibilityTracker.insert(&CB);
3886       ReachedUnknownParallelRegions.insert(&CB);
3887       indicatePessimisticFixpoint();
3888       return;
3889     case OMPRTL___kmpc_alloc_shared:
3890     case OMPRTL___kmpc_free_shared:
3891       // Return without setting a fixpoint, to be resolved in updateImpl.
3892       return;
3893     default:
3894       // Unknown OpenMP runtime calls cannot be executed in SPMD-mode,
3895       // generally.
3896       SPMDCompatibilityTracker.insert(&CB);
3897       indicatePessimisticFixpoint();
3898       return;
3899     }
3900     // All other OpenMP runtime calls will not reach parallel regions so they
3901     // can be safely ignored for now. Since it is a known OpenMP runtime call we
3902     // have now modeled all effects and there is no need for any update.
3903     indicateOptimisticFixpoint();
3904   }
3905 
3906   ChangeStatus updateImpl(Attributor &A) override {
3907     // TODO: Once we have call site specific value information we can provide
3908     //       call site specific liveness information and then it makes
3909     //       sense to specialize attributes for call sites arguments instead of
3910     //       redirecting requests to the callee argument.
3911     Function *F = getAssociatedFunction();
3912 
3913     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
3914     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(F);
3915 
3916     // If F is not a runtime function, propagate the AAKernelInfo of the callee.
3917     if (It == OMPInfoCache.RuntimeFunctionIDMap.end()) {
3918       const IRPosition &FnPos = IRPosition::function(*F);
3919       auto &FnAA = A.getAAFor<AAKernelInfo>(*this, FnPos, DepClassTy::REQUIRED);
3920       if (getState() == FnAA.getState())
3921         return ChangeStatus::UNCHANGED;
3922       getState() = FnAA.getState();
3923       return ChangeStatus::CHANGED;
3924     }
3925 
3926     // F is a runtime function that allocates or frees memory, check
3927     // AAHeapToStack and AAHeapToShared.
3928     KernelInfoState StateBefore = getState();
3929     assert((It->getSecond() == OMPRTL___kmpc_alloc_shared ||
3930             It->getSecond() == OMPRTL___kmpc_free_shared) &&
3931            "Expected a __kmpc_alloc_shared or __kmpc_free_shared runtime call");
3932 
3933     CallBase &CB = cast<CallBase>(getAssociatedValue());
3934 
3935     auto &HeapToStackAA = A.getAAFor<AAHeapToStack>(
3936         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3937     auto &HeapToSharedAA = A.getAAFor<AAHeapToShared>(
3938         *this, IRPosition::function(*CB.getCaller()), DepClassTy::OPTIONAL);
3939 
3940     RuntimeFunction RF = It->getSecond();
3941 
3942     switch (RF) {
3943     // If neither HeapToStack nor HeapToShared assume the call is removed,
3944     // assume SPMD incompatibility.
3945     case OMPRTL___kmpc_alloc_shared:
3946       if (!HeapToStackAA.isAssumedHeapToStack(CB) &&
3947           !HeapToSharedAA.isAssumedHeapToShared(CB))
3948         SPMDCompatibilityTracker.insert(&CB);
3949       break;
3950     case OMPRTL___kmpc_free_shared:
3951       if (!HeapToStackAA.isAssumedHeapToStackRemovedFree(CB) &&
3952           !HeapToSharedAA.isAssumedHeapToSharedRemovedFree(CB))
3953         SPMDCompatibilityTracker.insert(&CB);
3954       break;
3955     default:
3956       SPMDCompatibilityTracker.insert(&CB);
3957     }
3958 
3959     return StateBefore == getState() ? ChangeStatus::UNCHANGED
3960                                      : ChangeStatus::CHANGED;
3961   }
3962 };
3963 
3964 struct AAFoldRuntimeCall
3965     : public StateWrapper<BooleanState, AbstractAttribute> {
3966   using Base = StateWrapper<BooleanState, AbstractAttribute>;
3967 
3968   AAFoldRuntimeCall(const IRPosition &IRP, Attributor &A) : Base(IRP) {}
3969 
3970   /// Statistics are tracked as part of manifest for now.
3971   void trackStatistics() const override {}
3972 
3973   /// Create an abstract attribute view for the position \p IRP.
3974   static AAFoldRuntimeCall &createForPosition(const IRPosition &IRP,
3975                                               Attributor &A);
3976 
3977   /// See AbstractAttribute::getName()
3978   const std::string getName() const override { return "AAFoldRuntimeCall"; }
3979 
3980   /// See AbstractAttribute::getIdAddr()
3981   const char *getIdAddr() const override { return &ID; }
3982 
3983   /// This function should return true if the type of the \p AA is
3984   /// AAFoldRuntimeCall
3985   static bool classof(const AbstractAttribute *AA) {
3986     return (AA->getIdAddr() == &ID);
3987   }
3988 
3989   static const char ID;
3990 };
3991 
3992 struct AAFoldRuntimeCallCallSiteReturned : AAFoldRuntimeCall {
3993   AAFoldRuntimeCallCallSiteReturned(const IRPosition &IRP, Attributor &A)
3994       : AAFoldRuntimeCall(IRP, A) {}
3995 
3996   /// See AbstractAttribute::getAsStr()
3997   const std::string getAsStr() const override {
3998     if (!isValidState())
3999       return "<invalid>";
4000 
4001     std::string Str("simplified value: ");
4002 
4003     if (!SimplifiedValue.hasValue())
4004       return Str + std::string("none");
4005 
4006     if (!SimplifiedValue.getValue())
4007       return Str + std::string("nullptr");
4008 
4009     if (ConstantInt *CI = dyn_cast<ConstantInt>(SimplifiedValue.getValue()))
4010       return Str + std::to_string(CI->getSExtValue());
4011 
4012     return Str + std::string("unknown");
4013   }
4014 
4015   void initialize(Attributor &A) override {
4016     if (DisableOpenMPOptFolding)
4017       indicatePessimisticFixpoint();
4018 
4019     Function *Callee = getAssociatedFunction();
4020 
4021     auto &OMPInfoCache = static_cast<OMPInformationCache &>(A.getInfoCache());
4022     const auto &It = OMPInfoCache.RuntimeFunctionIDMap.find(Callee);
4023     assert(It != OMPInfoCache.RuntimeFunctionIDMap.end() &&
4024            "Expected a known OpenMP runtime function");
4025 
4026     RFKind = It->getSecond();
4027 
4028     CallBase &CB = cast<CallBase>(getAssociatedValue());
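    // Register a simplification callback so that other AAs querying the value
    // returned by this call site receive our (assumed) SimplifiedValue. While
    // we are not at a fixpoint the value is only assumed, so a dependence on
    // the querying AA is recorded.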
4029     A.registerSimplificationCallback(
4030         IRPosition::callsite_returned(CB),
4031         [&](const IRPosition &IRP, const AbstractAttribute *AA,
4032             bool &UsedAssumedInformation) -> Optional<Value *> {
4033           assert((isValidState() || (SimplifiedValue.hasValue() &&
4034                                      SimplifiedValue.getValue() == nullptr)) &&
4035                  "Unexpected invalid state!");
4036 
4037           if (!isAtFixpoint()) {
4038             UsedAssumedInformation = true;
4039             if (AA)
4040               A.recordDependence(*this, *AA, DepClassTy::OPTIONAL);
4041           }
4042           return SimplifiedValue;
4043         });
4044   }
4045 
4046   ChangeStatus updateImpl(Attributor &A) override {
4047     ChangeStatus Changed = ChangeStatus::UNCHANGED;
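    // Dispatch to the folding routine that matches the callee's runtime
    // function kind (determined in initialize()).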
4048     switch (RFKind) {
4049     case OMPRTL___kmpc_is_spmd_exec_mode:
4050       Changed |= foldIsSPMDExecMode(A);
4051       break;
4052     case OMPRTL___kmpc_is_generic_main_thread_id:
4053       Changed |= foldIsGenericMainThread(A);
4054       break;
4055     case OMPRTL___kmpc_parallel_level:
4056       Changed |= foldParallelLevel(A);
4057       break;
4058     case OMPRTL___kmpc_get_hardware_num_threads_in_block:
      Changed |= foldKernelFnAttribute(A, "omp_target_thread_limit");
4060       break;
4061     case OMPRTL___kmpc_get_hardware_num_blocks:
      Changed |= foldKernelFnAttribute(A, "omp_target_num_teams");
4063       break;
4064     default:
4065       llvm_unreachable("Unhandled OpenMP runtime function!");
4066     }
4067 
4068     return Changed;
4069   }
4070 
4071   ChangeStatus manifest(Attributor &A) override {
4072     ChangeStatus Changed = ChangeStatus::UNCHANGED;
4073 
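    // If a concrete replacement value was deduced, rewrite all uses of the
    // call with it and schedule the call for deletion.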
4074     if (SimplifiedValue.hasValue() && SimplifiedValue.getValue()) {
4075       Instruction &I = *getCtxI();
4076       A.changeValueAfterManifest(I, **SimplifiedValue);
4077       A.deleteAfterManifest(I);
4078 
4079       CallBase *CB = dyn_cast<CallBase>(&I);
4080       auto Remark = [&](OptimizationRemark OR) {
4081         if (auto *C = dyn_cast<ConstantInt>(*SimplifiedValue))
4082           return OR << "Replacing OpenMP runtime call "
4083                     << CB->getCalledFunction()->getName() << " with "
4084                     << ore::NV("FoldedValue", C->getZExtValue()) << ".";
4085         else
4086           return OR << "Replacing OpenMP runtime call "
4087                     << CB->getCalledFunction()->getName() << ".";
4088       };
4089 
4090       if (CB && EnableVerboseRemarks)
4091         A.emitRemark<OptimizationRemark>(CB, "OMP180", Remark);
4092 
4093       LLVM_DEBUG(dbgs() << TAG << "Replacing runtime call: " << I << " with "
4094                         << **SimplifiedValue << "\n");
4095 
4096       Changed = ChangeStatus::CHANGED;
4097     }
4098 
4099     return Changed;
4100   }
4101 
4102   ChangeStatus indicatePessimisticFixpoint() override {
4103     SimplifiedValue = nullptr;
4104     return AAFoldRuntimeCall::indicatePessimisticFixpoint();
4105   }
4106 
4107 private:
4108   /// Fold __kmpc_is_spmd_exec_mode into a constant if possible.
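  /// For example, a call to __kmpc_is_spmd_exec_mode() is folded to `i8 1` if
  /// every kernel that can reach the caller is (assumed to be) in SPMD mode,
  /// and to `i8 0` if every reaching kernel is (assumed to be) non-SPMD.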
4109   ChangeStatus foldIsSPMDExecMode(Attributor &A) {
4110     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4111 
4112     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4113     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4114     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4115         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4116 
4117     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4118       return indicatePessimisticFixpoint();
4119 
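    // Count the reaching kernels that are assumed or known to execute in SPMD
    // mode versus non-SPMD (generic) mode.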
4120     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4121       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4122                                           DepClassTy::REQUIRED);
4123 
4124       if (!AA.isValidState()) {
4125         SimplifiedValue = nullptr;
4126         return indicatePessimisticFixpoint();
4127       }
4128 
4129       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4130         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4131           ++KnownSPMDCount;
4132         else
4133           ++AssumedSPMDCount;
4134       } else {
4135         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4136           ++KnownNonSPMDCount;
4137         else
4138           ++AssumedNonSPMDCount;
4139       }
4140     }
4141 
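    // If both SPMD and non-SPMD kernels can reach this call site, the call
    // cannot be folded to a single constant.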
4142     if ((AssumedSPMDCount + KnownSPMDCount) &&
4143         (AssumedNonSPMDCount + KnownNonSPMDCount))
4144       return indicatePessimisticFixpoint();
4145 
4146     auto &Ctx = getAnchorValue().getContext();
4147     if (KnownSPMDCount || AssumedSPMDCount) {
4148       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4149              "Expected only SPMD kernels!");
      // All reaching kernels are in SPMD mode. Fold calls to
      // __kmpc_is_spmd_exec_mode to 1.
4152       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4153     } else if (KnownNonSPMDCount || AssumedNonSPMDCount) {
4154       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4155              "Expected only non-SPMD kernels!");
      // All reaching kernels are in non-SPMD mode. Fold calls to
      // __kmpc_is_spmd_exec_mode to 0.
4158       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), false);
4159     } else {
      // The set of reaching kernels is empty, so we cannot tell whether the
      // associated call site can be folded. At this point, SimplifiedValue
      // must still be none.
4163       assert(!SimplifiedValue.hasValue() && "SimplifiedValue should be none");
4164     }
4165 
4166     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4167                                                     : ChangeStatus::CHANGED;
4168   }
4169 
4170   /// Fold __kmpc_is_generic_main_thread_id into a constant if possible.
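  /// For example, if this call site is known to be executed only by the
  /// initial (main) thread, the call is folded to `i8 1`.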
4171   ChangeStatus foldIsGenericMainThread(Attributor &A) {
4172     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4173 
4174     CallBase &CB = cast<CallBase>(getAssociatedValue());
4175     Function *F = CB.getFunction();
4176     const auto &ExecutionDomainAA = A.getAAFor<AAExecutionDomain>(
4177         *this, IRPosition::function(*F), DepClassTy::REQUIRED);
4178 
4179     if (!ExecutionDomainAA.isValidState())
4180       return indicatePessimisticFixpoint();
4181 
4182     auto &Ctx = getAnchorValue().getContext();
4183     if (ExecutionDomainAA.isExecutedByInitialThreadOnly(CB))
4184       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), true);
4185     else
4186       return indicatePessimisticFixpoint();
4187 
4188     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4189                                                     : ChangeStatus::CHANGED;
4190   }
4191 
4192   /// Fold __kmpc_parallel_level into a constant if possible.
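  /// For example, if the caller is only reachable from SPMD kernels, the
  /// parallel level is folded to `i8 1`; if it is only reachable from
  /// generic-mode kernels, it is folded to `i8 0`.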
4193   ChangeStatus foldParallelLevel(Attributor &A) {
4194     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4195 
4196     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4197         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4198 
4199     if (!CallerKernelInfoAA.ParallelLevels.isValidState())
4200       return indicatePessimisticFixpoint();
4201 
4202     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4203       return indicatePessimisticFixpoint();
4204 
4205     if (CallerKernelInfoAA.ReachingKernelEntries.empty()) {
      assert(!SimplifiedValue.hasValue() &&
             "SimplifiedValue should still be none at this point");
4208       return ChangeStatus::UNCHANGED;
4209     }
4210 
4211     unsigned AssumedSPMDCount = 0, KnownSPMDCount = 0;
4212     unsigned AssumedNonSPMDCount = 0, KnownNonSPMDCount = 0;
4213     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4214       auto &AA = A.getAAFor<AAKernelInfo>(*this, IRPosition::function(*K),
4215                                           DepClassTy::REQUIRED);
4216       if (!AA.SPMDCompatibilityTracker.isValidState())
4217         return indicatePessimisticFixpoint();
4218 
4219       if (AA.SPMDCompatibilityTracker.isAssumed()) {
4220         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4221           ++KnownSPMDCount;
4222         else
4223           ++AssumedSPMDCount;
4224       } else {
4225         if (AA.SPMDCompatibilityTracker.isAtFixpoint())
4226           ++KnownNonSPMDCount;
4227         else
4228           ++AssumedNonSPMDCount;
4229       }
4230     }
4231 
4232     if ((AssumedSPMDCount + KnownSPMDCount) &&
4233         (AssumedNonSPMDCount + KnownNonSPMDCount))
4234       return indicatePessimisticFixpoint();
4235 
4236     auto &Ctx = getAnchorValue().getContext();
4237     // If the caller can only be reached by SPMD kernel entries, the parallel
4238     // level is 1. Similarly, if the caller can only be reached by non-SPMD
4239     // kernel entries, it is 0.
4240     if (AssumedSPMDCount || KnownSPMDCount) {
4241       assert(KnownNonSPMDCount == 0 && AssumedNonSPMDCount == 0 &&
4242              "Expected only SPMD kernels!");
4243       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 1);
4244     } else {
4245       assert(KnownSPMDCount == 0 && AssumedSPMDCount == 0 &&
4246              "Expected only non-SPMD kernels!");
4247       SimplifiedValue = ConstantInt::get(Type::getInt8Ty(Ctx), 0);
4248     }
4249     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4250                                                     : ChangeStatus::CHANGED;
4251   }
4252 
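  /// Fold a runtime call whose result is given by the kernel function
  /// attribute \p Attr. For example, if every reaching kernel carries
  /// "omp_target_thread_limit"="128", a call to
  /// __kmpc_get_hardware_num_threads_in_block is folded to `i32 128`.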
4253   ChangeStatus foldKernelFnAttribute(Attributor &A, llvm::StringRef Attr) {
    // Specialize only if all reaching kernels agree on the attribute value.
4255     int32_t CurrentAttrValue = -1;
4256     Optional<Value *> SimplifiedValueBefore = SimplifiedValue;
4257 
4258     auto &CallerKernelInfoAA = A.getAAFor<AAKernelInfo>(
4259         *this, IRPosition::function(*getAnchorScope()), DepClassTy::REQUIRED);
4260 
4261     if (!CallerKernelInfoAA.ReachingKernelEntries.isValidState())
4262       return indicatePessimisticFixpoint();
4263 
    // Iterate over the kernels that reach this function.
4265     for (Kernel K : CallerKernelInfoAA.ReachingKernelEntries) {
4266       int32_t NextAttrVal = -1;
4267       if (K->hasFnAttribute(Attr))
4268         NextAttrVal =
4269             std::stoi(K->getFnAttribute(Attr).getValueAsString().str());
4270 
4271       if (NextAttrVal == -1 ||
4272           (CurrentAttrValue != -1 && CurrentAttrValue != NextAttrVal))
4273         return indicatePessimisticFixpoint();
4274       CurrentAttrValue = NextAttrVal;
4275     }
4276 
4277     if (CurrentAttrValue != -1) {
4278       auto &Ctx = getAnchorValue().getContext();
4279       SimplifiedValue =
4280           ConstantInt::get(Type::getInt32Ty(Ctx), CurrentAttrValue);
4281     }
4282     return SimplifiedValue == SimplifiedValueBefore ? ChangeStatus::UNCHANGED
4283                                                     : ChangeStatus::CHANGED;
4284   }
4285 
4286   /// An optional value the associated value is assumed to fold to. That is, we
4287   /// assume the associated value (which is a call) can be replaced by this
4288   /// simplified value.
4289   Optional<Value *> SimplifiedValue;
4290 
4291   /// The runtime function kind of the callee of the associated call site.
4292   RuntimeFunction RFKind;
4293 };
4294 
4295 } // namespace
4296 
/// Register AAFoldRuntimeCall AAs for every call site of the runtime function
/// \p RF.
4298 void OpenMPOpt::registerFoldRuntimeCall(RuntimeFunction RF) {
4299   auto &RFI = OMPInfoCache.RFIs[RF];
4300   RFI.foreachUse(SCC, [&](Use &U, Function &F) {
4301     CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &RFI);
4302     if (!CI)
4303       return false;
4304     A.getOrCreateAAFor<AAFoldRuntimeCall>(
4305         IRPosition::callsite_returned(*CI), /* QueryingAA */ nullptr,
4306         DepClassTy::NONE, /* ForceUpdate */ false,
4307         /* UpdateAfterInit */ false);
4308     return false;
4309   });
4310 }
4311 
4312 void OpenMPOpt::registerAAs(bool IsModulePass) {
  if (SCC.empty())
    return;
4316   if (IsModulePass) {
4317     // Ensure we create the AAKernelInfo AAs first and without triggering an
4318     // update. This will make sure we register all value simplification
4319     // callbacks before any other AA has the chance to create an AAValueSimplify
4320     // or similar.
4321     for (Function *Kernel : OMPInfoCache.Kernels)
4322       A.getOrCreateAAFor<AAKernelInfo>(
4323           IRPosition::function(*Kernel), /* QueryingAA */ nullptr,
4324           DepClassTy::NONE, /* ForceUpdate */ false,
4325           /* UpdateAfterInit */ false);
4326 
4327     registerFoldRuntimeCall(OMPRTL___kmpc_is_generic_main_thread_id);
4328     registerFoldRuntimeCall(OMPRTL___kmpc_is_spmd_exec_mode);
4329     registerFoldRuntimeCall(OMPRTL___kmpc_parallel_level);
4330     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_threads_in_block);
4331     registerFoldRuntimeCall(OMPRTL___kmpc_get_hardware_num_blocks);
4332   }
4333 
  // Create a call site AAICVTracker for every call to an ICV getter.
4335   for (int Idx = 0; Idx < OMPInfoCache.ICVs.size() - 1; ++Idx) {
4336     auto ICVInfo = OMPInfoCache.ICVs[static_cast<InternalControlVar>(Idx)];
4337 
4338     auto &GetterRFI = OMPInfoCache.RFIs[ICVInfo.Getter];
4339 
4340     auto CreateAA = [&](Use &U, Function &Caller) {
4341       CallInst *CI = OpenMPOpt::getCallIfRegularCall(U, &GetterRFI);
4342       if (!CI)
4343         return false;
4344 
4345       auto &CB = cast<CallBase>(*CI);
4346 
4347       IRPosition CBPos = IRPosition::callsite_function(CB);
4348       A.getOrCreateAAFor<AAICVTracker>(CBPos);
4349       return false;
4350     };
4351 
4352     GetterRFI.foreachUse(SCC, CreateAA);
4353   }
4354   auto &GlobalizationRFI = OMPInfoCache.RFIs[OMPRTL___kmpc_alloc_shared];
4355   auto CreateAA = [&](Use &U, Function &F) {
4356     A.getOrCreateAAFor<AAHeapToShared>(IRPosition::function(F));
4357     return false;
4358   };
4359   if (!DisableOpenMPOptDeglobalization)
4360     GlobalizationRFI.foreachUse(SCC, CreateAA);
4361 
  // For device modules, create an AAExecutionDomain for every function and,
  // unless deglobalization is disabled, an AAHeapToStack as well.
4364   if (!isOpenMPDevice(M))
4365     return;
4366 
4367   for (auto *F : SCC) {
4368     if (F->isDeclaration())
4369       continue;
4370 
4371     A.getOrCreateAAFor<AAExecutionDomain>(IRPosition::function(*F));
4372     if (!DisableOpenMPOptDeglobalization)
4373       A.getOrCreateAAFor<AAHeapToStack>(IRPosition::function(*F));
4374 
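    // Eagerly query simplified values for loads so that the corresponding
    // value simplification AAs are created up front.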
4375     for (auto &I : instructions(*F)) {
4376       if (auto *LI = dyn_cast<LoadInst>(&I)) {
4377         bool UsedAssumedInformation = false;
4378         A.getAssumedSimplified(IRPosition::value(*LI), /* AA */ nullptr,
4379                                UsedAssumedInformation);
4380       }
4381     }
4382   }
4383 }
4384 
4385 const char AAICVTracker::ID = 0;
4386 const char AAKernelInfo::ID = 0;
4387 const char AAExecutionDomain::ID = 0;
4388 const char AAHeapToShared::ID = 0;
4389 const char AAFoldRuntimeCall::ID = 0;
4390 
4391 AAICVTracker &AAICVTracker::createForPosition(const IRPosition &IRP,
4392                                               Attributor &A) {
4393   AAICVTracker *AA = nullptr;
4394   switch (IRP.getPositionKind()) {
4395   case IRPosition::IRP_INVALID:
4396   case IRPosition::IRP_FLOAT:
4397   case IRPosition::IRP_ARGUMENT:
4398   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4399     llvm_unreachable("ICVTracker can only be created for function position!");
4400   case IRPosition::IRP_RETURNED:
4401     AA = new (A.Allocator) AAICVTrackerFunctionReturned(IRP, A);
4402     break;
4403   case IRPosition::IRP_CALL_SITE_RETURNED:
4404     AA = new (A.Allocator) AAICVTrackerCallSiteReturned(IRP, A);
4405     break;
4406   case IRPosition::IRP_CALL_SITE:
4407     AA = new (A.Allocator) AAICVTrackerCallSite(IRP, A);
4408     break;
4409   case IRPosition::IRP_FUNCTION:
4410     AA = new (A.Allocator) AAICVTrackerFunction(IRP, A);
4411     break;
4412   }
4413 
4414   return *AA;
4415 }
4416 
4417 AAExecutionDomain &AAExecutionDomain::createForPosition(const IRPosition &IRP,
4418                                                         Attributor &A) {
4419   AAExecutionDomainFunction *AA = nullptr;
4420   switch (IRP.getPositionKind()) {
4421   case IRPosition::IRP_INVALID:
4422   case IRPosition::IRP_FLOAT:
4423   case IRPosition::IRP_ARGUMENT:
4424   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4425   case IRPosition::IRP_RETURNED:
4426   case IRPosition::IRP_CALL_SITE_RETURNED:
4427   case IRPosition::IRP_CALL_SITE:
4428     llvm_unreachable(
4429         "AAExecutionDomain can only be created for function position!");
4430   case IRPosition::IRP_FUNCTION:
4431     AA = new (A.Allocator) AAExecutionDomainFunction(IRP, A);
4432     break;
4433   }
4434 
4435   return *AA;
4436 }
4437 
4438 AAHeapToShared &AAHeapToShared::createForPosition(const IRPosition &IRP,
4439                                                   Attributor &A) {
4440   AAHeapToSharedFunction *AA = nullptr;
4441   switch (IRP.getPositionKind()) {
4442   case IRPosition::IRP_INVALID:
4443   case IRPosition::IRP_FLOAT:
4444   case IRPosition::IRP_ARGUMENT:
4445   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4446   case IRPosition::IRP_RETURNED:
4447   case IRPosition::IRP_CALL_SITE_RETURNED:
4448   case IRPosition::IRP_CALL_SITE:
4449     llvm_unreachable(
4450         "AAHeapToShared can only be created for function position!");
4451   case IRPosition::IRP_FUNCTION:
4452     AA = new (A.Allocator) AAHeapToSharedFunction(IRP, A);
4453     break;
4454   }
4455 
4456   return *AA;
4457 }
4458 
4459 AAKernelInfo &AAKernelInfo::createForPosition(const IRPosition &IRP,
4460                                               Attributor &A) {
4461   AAKernelInfo *AA = nullptr;
4462   switch (IRP.getPositionKind()) {
4463   case IRPosition::IRP_INVALID:
4464   case IRPosition::IRP_FLOAT:
4465   case IRPosition::IRP_ARGUMENT:
4466   case IRPosition::IRP_RETURNED:
4467   case IRPosition::IRP_CALL_SITE_RETURNED:
4468   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4469     llvm_unreachable("KernelInfo can only be created for function position!");
4470   case IRPosition::IRP_CALL_SITE:
4471     AA = new (A.Allocator) AAKernelInfoCallSite(IRP, A);
4472     break;
4473   case IRPosition::IRP_FUNCTION:
4474     AA = new (A.Allocator) AAKernelInfoFunction(IRP, A);
4475     break;
4476   }
4477 
4478   return *AA;
4479 }
4480 
4481 AAFoldRuntimeCall &AAFoldRuntimeCall::createForPosition(const IRPosition &IRP,
4482                                                         Attributor &A) {
4483   AAFoldRuntimeCall *AA = nullptr;
4484   switch (IRP.getPositionKind()) {
4485   case IRPosition::IRP_INVALID:
4486   case IRPosition::IRP_FLOAT:
4487   case IRPosition::IRP_ARGUMENT:
4488   case IRPosition::IRP_RETURNED:
4489   case IRPosition::IRP_FUNCTION:
4490   case IRPosition::IRP_CALL_SITE:
4491   case IRPosition::IRP_CALL_SITE_ARGUMENT:
4492     llvm_unreachable("KernelInfo can only be created for call site position!");
4493   case IRPosition::IRP_CALL_SITE_RETURNED:
4494     AA = new (A.Allocator) AAFoldRuntimeCallCallSiteReturned(IRP, A);
4495     break;
4496   }
4497 
4498   return *AA;
4499 }
4500 
4501 PreservedAnalyses OpenMPOptPass::run(Module &M, ModuleAnalysisManager &AM) {
4502   if (!containsOpenMP(M))
4503     return PreservedAnalyses::all();
4504   if (DisableOpenMPOptimizations)
4505     return PreservedAnalyses::all();
4506 
4507   FunctionAnalysisManager &FAM =
4508       AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
4509   KernelSet Kernels = getDeviceKernels(M);
4510 
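  // A function counts as called if it is a device kernel or has any user
  // other than a BlockAddress.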
4511   auto IsCalled = [&](Function &F) {
4512     if (Kernels.contains(&F))
4513       return true;
4514     for (const User *U : F.users())
4515       if (!isa<BlockAddress>(U))
4516         return true;
4517     return false;
4518   };
4519 
4520   auto EmitRemark = [&](Function &F) {
4521     auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F);
4522     ORE.emit([&]() {
4523       OptimizationRemarkAnalysis ORA(DEBUG_TYPE, "OMP140", &F);
4524       return ORA << "Could not internalize function. "
4525                  << "Some optimizations may not be possible. [OMP140]";
4526     });
4527   };
4528 
  // Create internal copies of each function if this is a kernel module. This
  // allows interprocedural passes to see every call edge.
4531   DenseMap<Function *, Function *> InternalizedMap;
4532   if (isOpenMPDevice(M)) {
4533     SmallPtrSet<Function *, 16> InternalizeFns;
4534     for (Function &F : M)
4535       if (!F.isDeclaration() && !Kernels.contains(&F) && IsCalled(F) &&
4536           !DisableInternalization) {
4537         if (Attributor::isInternalizable(F)) {
4538           InternalizeFns.insert(&F);
4539         } else if (!F.hasLocalLinkage() && !F.hasFnAttribute(Attribute::Cold)) {
4540           EmitRemark(F);
4541         }
4542       }
4543 
4544     Attributor::internalizeFunctions(InternalizeFns, InternalizedMap);
4545   }
4546 
4547   // Look at every function in the Module unless it was internalized.
4548   SmallVector<Function *, 16> SCC;
4549   for (Function &F : M)
4550     if (!F.isDeclaration() && !InternalizedMap.lookup(&F))
4551       SCC.push_back(&F);
4552 
4553   if (SCC.empty())
4554     return PreservedAnalyses::all();
4555 
4556   AnalysisGetter AG(FAM);
4557 
4558   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4559     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4560   };
4561 
4562   BumpPtrAllocator Allocator;
4563   CallGraphUpdater CGUpdater;
4564 
4565   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4566   OMPInformationCache InfoCache(M, AG, Allocator, /*CGSCC*/ Functions, Kernels);
4567 
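  // Device modules get a larger fixpoint iteration budget than host modules.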
4568   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4569   Attributor A(Functions, InfoCache, CGUpdater, nullptr, true, false,
4570                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4571 
4572   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4573   bool Changed = OMPOpt.run(true);
4574 
4575   // Optionally inline device functions for potentially better performance.
4576   if (AlwaysInlineDeviceFunctions && isOpenMPDevice(M))
4577     for (Function &F : M)
4578       if (!F.isDeclaration() && !Kernels.contains(&F) &&
4579           !F.hasFnAttribute(Attribute::NoInline))
4580         F.addFnAttr(Attribute::AlwaysInline);
4581 
4582   if (PrintModuleAfterOptimizations)
4583     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt Module Pass:\n" << M);
4584 
4585   if (Changed)
4586     return PreservedAnalyses::none();
4587 
4588   return PreservedAnalyses::all();
4589 }
4590 
4591 PreservedAnalyses OpenMPOptCGSCCPass::run(LazyCallGraph::SCC &C,
4592                                           CGSCCAnalysisManager &AM,
4593                                           LazyCallGraph &CG,
4594                                           CGSCCUpdateResult &UR) {
4595   if (!containsOpenMP(*C.begin()->getFunction().getParent()))
4596     return PreservedAnalyses::all();
4597   if (DisableOpenMPOptimizations)
4598     return PreservedAnalyses::all();
4599 
4600   SmallVector<Function *, 16> SCC;
  // If there are kernels in the module, we have to run on all SCCs.
4602   for (LazyCallGraph::Node &N : C) {
4603     Function *Fn = &N.getFunction();
4604     SCC.push_back(Fn);
4605   }
4606 
4607   if (SCC.empty())
4608     return PreservedAnalyses::all();
4609 
4610   Module &M = *C.begin()->getFunction().getParent();
4611 
4612   KernelSet Kernels = getDeviceKernels(M);
4613 
4614   FunctionAnalysisManager &FAM =
4615       AM.getResult<FunctionAnalysisManagerCGSCCProxy>(C, CG).getManager();
4616 
4617   AnalysisGetter AG(FAM);
4618 
4619   auto OREGetter = [&FAM](Function *F) -> OptimizationRemarkEmitter & {
4620     return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F);
4621   };
4622 
4623   BumpPtrAllocator Allocator;
4624   CallGraphUpdater CGUpdater;
4625   CGUpdater.initialize(CG, C, AM, UR);
4626 
4627   SetVector<Function *> Functions(SCC.begin(), SCC.end());
4628   OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG, Allocator,
4629                                 /*CGSCC*/ Functions, Kernels);
4630 
4631   unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4632   Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4633                MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4634 
4635   OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4636   bool Changed = OMPOpt.run(false);
4637 
4638   if (PrintModuleAfterOptimizations)
4639     LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4640 
4641   if (Changed)
4642     return PreservedAnalyses::none();
4643 
4644   return PreservedAnalyses::all();
4645 }
4646 
4647 namespace {
4648 
4649 struct OpenMPOptCGSCCLegacyPass : public CallGraphSCCPass {
4650   CallGraphUpdater CGUpdater;
4651   static char ID;
4652 
4653   OpenMPOptCGSCCLegacyPass() : CallGraphSCCPass(ID) {
4654     initializeOpenMPOptCGSCCLegacyPassPass(*PassRegistry::getPassRegistry());
4655   }
4656 
4657   void getAnalysisUsage(AnalysisUsage &AU) const override {
4658     CallGraphSCCPass::getAnalysisUsage(AU);
4659   }
4660 
4661   bool runOnSCC(CallGraphSCC &CGSCC) override {
4662     if (!containsOpenMP(CGSCC.getCallGraph().getModule()))
4663       return false;
4664     if (DisableOpenMPOptimizations || skipSCC(CGSCC))
4665       return false;
4666 
4667     SmallVector<Function *, 16> SCC;
    // If there are kernels in the module, we have to run on all SCCs.
4669     for (CallGraphNode *CGN : CGSCC) {
4670       Function *Fn = CGN->getFunction();
4671       if (!Fn || Fn->isDeclaration())
4672         continue;
4673       SCC.push_back(Fn);
4674     }
4675 
4676     if (SCC.empty())
4677       return false;
4678 
4679     Module &M = CGSCC.getCallGraph().getModule();
4680     KernelSet Kernels = getDeviceKernels(M);
4681 
4682     CallGraph &CG = getAnalysis<CallGraphWrapperPass>().getCallGraph();
4683     CGUpdater.initialize(CG, CGSCC);
4684 
    // Maintain a map from functions to remark emitters to avoid rebuilding
    // the ORE on every query.
4686     DenseMap<Function *, std::unique_ptr<OptimizationRemarkEmitter>> OREMap;
4687     auto OREGetter = [&OREMap](Function *F) -> OptimizationRemarkEmitter & {
4688       std::unique_ptr<OptimizationRemarkEmitter> &ORE = OREMap[F];
4689       if (!ORE)
4690         ORE = std::make_unique<OptimizationRemarkEmitter>(F);
4691       return *ORE;
4692     };
4693 
4694     AnalysisGetter AG;
4695     SetVector<Function *> Functions(SCC.begin(), SCC.end());
4696     BumpPtrAllocator Allocator;
4697     OMPInformationCache InfoCache(*(Functions.back()->getParent()), AG,
4698                                   Allocator,
4699                                   /*CGSCC*/ Functions, Kernels);
4700 
4701     unsigned MaxFixpointIterations = (isOpenMPDevice(M)) ? 128 : 32;
4702     Attributor A(Functions, InfoCache, CGUpdater, nullptr, false, true,
4703                  MaxFixpointIterations, OREGetter, DEBUG_TYPE);
4704 
4705     OpenMPOpt OMPOpt(SCC, CGUpdater, OREGetter, InfoCache, A);
4706     bool Result = OMPOpt.run(false);
4707 
4708     if (PrintModuleAfterOptimizations)
4709       LLVM_DEBUG(dbgs() << TAG << "Module after OpenMPOpt CGSCC Pass:\n" << M);
4710 
4711     return Result;
4712   }
4713 
4714   bool doFinalization(CallGraph &CG) override { return CGUpdater.finalize(); }
4715 };
4716 
4717 } // end anonymous namespace
4718 
4719 KernelSet llvm::omp::getDeviceKernels(Module &M) {
4720   // TODO: Create a more cross-platform way of determining device kernels.
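  // Kernels are identified via "nvvm.annotations" metadata; each operand is
  // expected to look roughly like !{ptr @kernel, !"kernel", i32 1}.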
  NamedMDNode *MD = M.getNamedMetadata("nvvm.annotations");
4722   KernelSet Kernels;
4723 
4724   if (!MD)
4725     return Kernels;
4726 
4727   for (auto *Op : MD->operands()) {
4728     if (Op->getNumOperands() < 2)
4729       continue;
4730     MDString *KindID = dyn_cast<MDString>(Op->getOperand(1));
4731     if (!KindID || KindID->getString() != "kernel")
4732       continue;
4733 
4734     Function *KernelFn =
4735         mdconst::dyn_extract_or_null<Function>(Op->getOperand(0));
4736     if (!KernelFn)
4737       continue;
4738 
4739     ++NumOpenMPTargetRegionKernels;
4740 
4741     Kernels.insert(KernelFn);
4742   }
4743 
4744   return Kernels;
4745 }
4746 
4747 bool llvm::omp::containsOpenMP(Module &M) {
4748   Metadata *MD = M.getModuleFlag("openmp");
4749   if (!MD)
4750     return false;
4751 
4752   return true;
4753 }
4754 
4755 bool llvm::omp::isOpenMPDevice(Module &M) {
4756   Metadata *MD = M.getModuleFlag("openmp-device");
4757   if (!MD)
4758     return false;
4759 
4760   return true;
4761 }
4762 
4763 char OpenMPOptCGSCCLegacyPass::ID = 0;
4764 
4765 INITIALIZE_PASS_BEGIN(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4766                       "OpenMP specific optimizations", false, false)
4767 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass)
4768 INITIALIZE_PASS_END(OpenMPOptCGSCCLegacyPass, "openmp-opt-cgscc",
4769                     "OpenMP specific optimizations", false, false)
4770 
4771 Pass *llvm::createOpenMPOptCGSCCLegacyPass() {
4772   return new OpenMPOptCGSCCLegacyPass();
4773 }
4774